{ "cells": [ { "cell_type": "markdown", "id": "5fbc2d16-59f9-4be3-b93e-1a5440c7efd0", "metadata": {}, "source": [ "# Tutorial 8 - Adaptive PIKANs" ] }, { "cell_type": "markdown", "id": "1afe6a1e-3ab4-4f3f-ad47-f6cd66419504", "metadata": {}, "source": [ "In Tutorial 6 we got a taste of solving PDEs with KANs, and in Tutorial 7 we started exploring adaptive training techniques. Building on this **adaptive training** idea, this tutorial shows how to train PIKANs adaptively, based on the findings of [this](https://ieeexplore.ieee.org/document/10763509) and [this](https://www.sciencedirect.com/science/article/pii/S0045782526000356) papers." ] }, { "cell_type": "code", "execution_count": 1, "id": "0a2ef2a6-f681-427f-8252-ade2111ce0e6", "metadata": {}, "outputs": [], "source": [ "from jaxkan.models.KAN import KAN\n", "\n", "import jax\n", "import jax.numpy as jnp\n", "\n", "from jaxkan.pikan.pde import get_ac_res\n", "from jaxkan.pikan.sampling import get_collocs_grid\n", "from jaxkan.pikan.adaptive import (\n", "    apply_rba_weights,\n", "    get_causal_weights,\n", "    get_colloc_indices,\n", "    get_rad_indices,\n", "    lr_anneal,\n", "    update_rba_weights,\n", ")\n", "\n", "from typing import Union, List\n", "\n", "from flax import nnx\n", "import optax\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "import os\n", "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\"" ] }, { "cell_type": "code", "execution_count": null, "id": "4144965f-84cf-4d9d-a95d-e98b4c9e6543", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "60e70a5d-0340-4bdc-af4e-150d0098a87e", "metadata": {}, "source": [ "## Data Generation" ] }, { "cell_type": "markdown", "id": "166ad90d-430e-45ca-86a8-e6e4dbc1c943", "metadata": {}, "source": [ "Since Burgers' equation was relatively easy to solve even without adaptive techniques, here we solve the Allen-Cahn equation,\n", "\n", "$$ \\frac{\\partial u}{\\partial t} - D\\frac{\\partial^2 
u}{\\partial x^2} + 5 \\left(u^3 - u\\right) = 0,$$\n", "\n", "for $D = 10^{-4}$ in the $\\Omega = [0,1]\\times [-1, 1]$ domain, subject to the initial condition\n", "\n", "$$ u\\left(t=0, x\\right) = x^2 \\cos\\left(\\pi x\\right), $$\n", "\n", "and the boundary conditions\n", "\n", "$$ u\\left(t, x=-1\\right) = u\\left(t, x=1\\right) = -1. $$\n", "\n", "We first define the collocation points. This time, we create a large pool of PDE collocation points on a $2^7 \\times 2^7$ grid and, during training, draw batches from it using RAD (residual-based adaptive distribution) resampling (see the two referenced papers for more information). The initial condition is imposed through a separate set of $2^6$ points at $t = 0$." ] }, { "cell_type": "code", "execution_count": 2, "id": "b986e75a-6d4a-402f-bea7-36d13f4a7866", "metadata": {}, "outputs": [], "source": [ "seed = 42\n", "\n", "# Generate Collocation points for PDE\n", "collocs_pool = get_collocs_grid(ranges=[(0, 1, 2**7), (-1, 1, 2**7)])\n", "\n", "# Generate Collocation points for IC\n", "ic_collocs = get_collocs_grid(ranges=[(0, 0, 1), (-1, 1, 2**6)])\n", "ic_data = ((ic_collocs[:,1]**2)*jnp.cos(jnp.pi*ic_collocs[:,1])).reshape(-1,1)" ] }, { "cell_type": "code", "execution_count": null, "id": "0dbe48b2-a680-4043-a38c-fb18134229d1", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "4e9d9bba-71e2-4d5b-95b1-36167fb70c1a", "metadata": {}, "source": [ "## Custom KAN Model" ] }, { "cell_type": "markdown", "id": "2fd54c6c-3d28-4864-af38-3b22875d0f0a", "metadata": {}, "source": [ "Note that we have not defined any collocation points for the boundary conditions. Instead, we use what we learned in \"DIY KANs\" to define a custom KAN model that enforces the boundary conditions exactly through its architecture (see [this](https://doi.org/10.1016/j.cma.2021.114333) paper for more details). Specifically, we define a wrapper class that inherits from `KAN`, holds an inner `KAN` instance, and transforms its output $N(t, x)$ into $u(t, x) = (1 - x^2)\\,N(t, x) - 1$. Because $1 - x^2 = 0$ at $x = \\pm 1$, this gives $u = -1$ on both boundaries for any network output, so no boundary loss term is needed." 
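] }, { "cell_type": "markdown", "id": "c7a1e3f0-2b6d-4e8a-9f15-3d0b7c4e2a61", "metadata": {}, "source": [ "As a quick sanity check of this construction, the sketch below applies the same output transformation that the `KANWrapper` class uses, $u = (1 - x^2)\\,N - 1$, to random stand-in values for the network output $N$ at points on the boundaries $x = \\pm 1$. The helper `hard_bc_ansatz` is hypothetical and is not part of jaxkan; it only mirrors the last step of the wrapper's forward pass, so the check holds for any output that a trained or untrained network could produce." ] }, { "cell_type": "code", "execution_count": null, "id": "5e9b2d47-8c3a-4f61-b0e2-7a4d1c9f3b58", "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "\n", "\n", "def hard_bc_ansatz(collocs, raw_out):\n", "    # Hypothetical helper mirroring KANWrapper: u = (1 - x^2) * N(t, x) - 1\n", "    x_coord = collocs[:, 1:2]\n", "    return (1 - x_coord**2) * raw_out - 1.0\n", "\n", "\n", "# Points (t, x) on the boundaries x = -1 and x = 1 at a few times t\n", "t = jnp.linspace(0.0, 1.0, 5)\n", "left = jnp.stack([t, -jnp.ones_like(t)], axis=1)\n", "right = jnp.stack([t, jnp.ones_like(t)], axis=1)\n", "boundary = jnp.concatenate([left, right], axis=0)\n", "\n", "# Random stand-in for the network output N(t, x), one value per point\n", "raw = jax.random.normal(jax.random.PRNGKey(0), (boundary.shape[0], 1))\n", "\n", "# Every boundary value should equal -1, whatever raw contains\n", "u_boundary = hard_bc_ansatz(boundary, raw)\n", "print(jnp.allclose(u_boundary, -1.0))" 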
] }, { "cell_type": "code", "execution_count": 3, "id": "09826dd0-7d2a-4390-88bd-70bed81c0581", "metadata": {}, "outputs": [], "source": [ "class KANWrapper(KAN):\n", "\n", "    def __init__(self, layer_dims: List[int], layer_type: str = \"base\",\n", "                 required_parameters: Union[None, dict] = None, seed: int = 42\n", "                 ):\n", "\n", "        self.model = KAN(layer_dims, layer_type, required_parameters, seed)\n", "\n", "    def __call__(self, x):\n", "\n", "        original_x = x\n", "\n", "        y = self.model(x)\n", "\n", "        # Impose BC u(t, -1) = u(t, 1) = -1\n", "        x_coord = original_x[:, 1:2]\n", "        # At x = -1 or x = 1 the factor (1 - x_coord**2) vanishes, cancelling the network output, so the -1 term yields u = -1, as required\n", "        y = (1 - x_coord**2) * y - 1.0\n", "\n", "        return y" ] }, { "cell_type": "code", "execution_count": 4, "id": "b350b56f-5daa-411f-9090-b544760c34ef", "metadata": {}, "outputs": [], "source": [ "# Initialize a KAN model\n", "n_in = collocs_pool.shape[1]\n", "n_out = 1\n", "n_hidden = 12\n", "\n", "layer_dims = [n_in, n_hidden, n_hidden, n_hidden, n_out]\n", "req_params = {'D': 5, 'flavor': 'exact', 'residual': None, 'external_weights': False, 'init_scheme': {'type': 'glorot_fine'}, 'add_bias': True}\n", "\n", "model = KANWrapper(layer_dims = layer_dims,\n", "                   layer_type = 'chebyshev',\n", "                   required_parameters = req_params,\n", "                   seed = seed\n", "                   )\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "0051714b-2f23-45c9-994f-90a6a0705da2", "metadata": {}, "outputs": [], "source": [ "# Adam optimizer with an exponentially decaying learning-rate schedule\n", "lr_schedule = optax.exponential_decay(\n", "    init_value=1e-3,\n", "    transition_steps=1000,\n", "    decay_rate=0.9,\n", "    staircase=False\n", ")\n", "\n", "opt_type = optax.adam(learning_rate=lr_schedule, b1=0.9, b2=0.999, eps=1e-8)\n", "\n", "optimizer = nnx.Optimizer(model, opt_type, wrt=nnx.Param)" ] }, { "cell_type": "code", "execution_count": null, "id": 
"db825176-cba8-43e4-9488-dadd2d3b6672", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "f681d135-c885-4e59-87d7-1ea16a157c50", "metadata": {}, "source": [ "## Adaptive Training" ] }, { "cell_type": "markdown", "id": "eaac2ddc-ad53-40db-804b-7f54f3000476", "metadata": {}, "source": [ "Unlike in Tutorial 6, where we trained the network with a simple MSE loss, here we combine several adaptive training methods: residual-based attention (RBA), which assigns each collocation point an adaptive weight based on its residual history; residual-based adaptive distribution (RAD) resampling, which periodically redraws the batch of PDE collocation points from the pool, favoring high-residual regions; learning rate annealing, which rebalances the global weights $\\lambda_E$ and $\\lambda_I$ of the PDE and initial-condition losses using their separate gradients; and causal training, which splits the PDE loss into time chunks and down-weights later chunks until earlier ones are resolved (for an in-depth look, read [Training Deep Physics-Informed Kolmogorov–Arnold Networks](https://www.sciencedirect.com/science/article/pii/S0045782526000356))." ] }, { "cell_type": "code", "execution_count": 6, "id": "4d7e1c78-70b4-4b0f-9b8b-cb516465405e", "metadata": {}, "outputs": [], "source": [ "@nnx.jit\n", "def get_rad_collocs(model, pde_collocs_pool, sorted_indices, l_pde, l_pde_pool):\n", "    resids_pool = pde_res(model, pde_collocs_pool)\n", "    new_indices, new_pool, _ = get_rad_indices(\n", "        collocs_pool=pde_collocs_pool,\n", "        residuals=resids_pool,\n", "        old_indices=sorted_indices,\n", "        batch_weights=l_pde,\n", "        pool_weights=l_pde_pool,\n", "        batch_size=batch_size,\n", "        rad_a=rad_a,\n", "        rad_c=rad_c,\n", "        seed=seed,\n", "    )\n", "    return new_indices, new_pool" ] }, { "cell_type": "code", "execution_count": 7, "id": "65d77457-e41c-4e66-a15b-2b4f709eb022", "metadata": {}, "outputs": [], "source": [ "# PDE Residual\n", "pde_res = get_ac_res()\n", "\n", "# PDE Loss\n", "def pde_loss(model, l_E, collocs):\n", "\n", "    residuals = pde_res(model, collocs) # shape (batch_size, 1)\n", "\n", "    # Get new RBA weights\n", "    l_E_new = update_rba_weights(residuals, l_E, gamma=RBA_gamma, eta=RBA_eta)\n", "\n", "    # Multiply by RBA weights while keeping them out of the backward graph\n", "    w_resids = apply_rba_weights(residuals, l_E_new)\n", "\n", "    # Reshape residuals for causal training\n", "    residuals = w_resids.reshape(num_chunks, -1) # shape (num_chunks, 
points)\n", "\n", "    # Get average loss per chunk\n", "    loss = jnp.mean(residuals**2, axis=1)\n", "\n", "    # Get causal weights\n", "    weights = get_causal_weights(loss, M, causal_tol)\n", "\n", "    # Weighted loss\n", "    weighted_loss = jnp.mean(weights * loss)\n", "\n", "    return weighted_loss, l_E_new\n", "\n", "\n", "def ic_loss(model, l_I, ic_collocs, ic_data):\n", "\n", "    # Residual\n", "    ic_res = model(ic_collocs) - ic_data\n", "\n", "    # Get new RBA weights\n", "    l_I_new = update_rba_weights(ic_res, l_I, gamma=RBA_gamma, eta=RBA_eta)\n", "\n", "    # Multiply by RBA weights while keeping them out of the backward graph\n", "    w_resids = apply_rba_weights(ic_res, l_I_new)\n", "\n", "    # Loss\n", "    loss = jnp.mean(w_resids**2)\n", "\n", "    return loss, l_I_new\n", "\n", "\n", "@nnx.jit(static_argnames=(\"compute_grads_sep\",))\n", "def train_step(model, optimizer, collocs, ic_collocs, ic_data, λ_E, λ_I, l_E, l_I, compute_grads_sep=False):\n", "\n", "    def total_loss_fn(model):\n", "        (loss_E, l_E_new) = pde_loss(model, l_E, collocs)\n", "        (loss_I, l_I_new) = ic_loss(model, l_I, ic_collocs, ic_data)\n", "        total = λ_E * loss_E + λ_I * loss_I\n", "        return total, (loss_E, loss_I, l_E_new, l_I_new)\n", "\n", "    (loss, aux), grads = nnx.value_and_grad(total_loss_fn, has_aux=True)(model)\n", "\n", "    sep_grads = (None, None)\n", "    if compute_grads_sep:\n", "        grads_E = nnx.grad(lambda m: pde_loss(m, l_E, collocs)[0])(model)\n", "        grads_I = nnx.grad(lambda m: ic_loss(m, l_I, ic_collocs, ic_data)[0])(model)\n", "        sep_grads = (grads_E, grads_I)\n", "\n", "    optimizer.update(model, grads)\n", "\n", "    return loss, aux, sep_grads" ] }, { "cell_type": "code", "execution_count": 8, "id": "3196e686-4998-409b-94d2-c3dd3de3b98b", "metadata": {}, "outputs": [], "source": [ "num_epochs = 20_000\n", "\n", "# Define causal training parameters\n", "causal_tol = 1.0\n", "num_chunks = 32\n", "M = jnp.triu(jnp.ones((num_chunks, num_chunks)), k=1).T\n", "\n", "# Define LR Annealing parameters\n", 
"grad_mixing = 0.9\n", "f_grad_norm = 1000\n", "\n", "# Define resampling parameters\n", "batch_size = 2**12\n", "f_resample = 2000\n", "rad_a = 1.0\n", "rad_c = 1.0\n", "\n", "# Define RBA parameters\n", "RBA_gamma = 0.999\n", "RBA_eta = 0.01" ] }, { "cell_type": "code", "execution_count": 9, "id": "06e24b36-d5b3-4b75-9145-39aaf5743180", "metadata": {}, "outputs": [], "source": [ "# Initialize RBA weights - full pool\n", "l_E_pool = jnp.ones((collocs_pool.shape[0], 1))\n", "# Also get RBAs for ICs\n", "l_I = jnp.ones((ic_collocs.shape[0], 1))\n", "\n", "# Get starting collocation points & RBA weights\n", "sorted_indices = get_colloc_indices(collocs_pool=collocs_pool, batch_size=batch_size, px=None, seed=seed)\n", "\n", "pde_collocs = collocs_pool[sorted_indices]\n", "l_E = l_E_pool[sorted_indices]\n", "\n", "# Define global loss weights (initialization)\n", "λ_E = jnp.array(1.0, dtype=float)\n", "λ_I = jnp.array(1.0, dtype=float)" ] }, { "cell_type": "markdown", "id": "3d54a2ff-0795-40bd-92c9-32f28ad25d6e", "metadata": {}, "source": [ "Following this setup, we proceed to train the model." ] }, { "cell_type": "code", "execution_count": 10, "id": "c5459ff8-5e00-45b4-9e2a-95f1cd59cf7f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch No. 1000. Current loss: 1.80e-02. Performing learning-rate annealing.\n", "Epoch No. 2000. Current loss: 3.39e-02. Performing learning-rate annealing.\n", "Epoch No. 2000. Current loss: 3.39e-02. Performing RAD resampling.\n", "Epoch No. 3000. Current loss: 4.19e-02. Performing learning-rate annealing.\n", "Epoch No. 4000. Current loss: 4.89e-02. Performing learning-rate annealing.\n", "Epoch No. 4000. Current loss: 4.89e-02. Performing RAD resampling.\n", "Epoch No. 5000. Current loss: 4.17e-02. Performing learning-rate annealing.\n", "Epoch No. 6000. Current loss: 4.21e-03. Performing learning-rate annealing.\n", "Epoch No. 6000. Current loss: 4.21e-03. 
Performing RAD resampling.\n", "Epoch No. 7000. Current loss: 3.18e-03. Performing learning-rate annealing.\n", "Epoch No. 8000. Current loss: 3.07e-03. Performing learning-rate annealing.\n", "Epoch No. 8000. Current loss: 3.07e-03. Performing RAD resampling.\n", "Epoch No. 9000. Current loss: 2.66e-03. Performing learning-rate annealing.\n", "Epoch No. 10000. Current loss: 2.53e-03. Performing learning-rate annealing.\n", "Epoch No. 10000. Current loss: 2.53e-03. Performing RAD resampling.\n", "Epoch No. 11000. Current loss: 2.85e-03. Performing learning-rate annealing.\n", "Epoch No. 12000. Current loss: 2.42e-03. Performing learning-rate annealing.\n", "Epoch No. 12000. Current loss: 2.42e-03. Performing RAD resampling.\n", "Epoch No. 13000. Current loss: 2.58e-03. Performing learning-rate annealing.\n", "Epoch No. 14000. Current loss: 2.26e-03. Performing learning-rate annealing.\n", "Epoch No. 14000. Current loss: 2.26e-03. Performing RAD resampling.\n", "Epoch No. 15000. Current loss: 1.80e-03. Performing learning-rate annealing.\n", "Epoch No. 16000. Current loss: 1.40e-03. Performing learning-rate annealing.\n", "Epoch No. 16000. Current loss: 1.40e-03. Performing RAD resampling.\n", "Epoch No. 17000. Current loss: 1.26e-03. Performing learning-rate annealing.\n", "Epoch No. 18000. Current loss: 1.15e-03. Performing learning-rate annealing.\n", "Epoch No. 18000. Current loss: 1.15e-03. Performing RAD resampling.\n", "Epoch No. 19000. Current loss: 1.02e-03. 
Performing learning-rate annealing.\n" ] } ], "source": [ "train_losses = jnp.zeros((num_epochs,))\n", "\n", "# Start training\n", "for epoch in range(num_epochs):\n", "\n", "    do_anneal = (epoch != 0) and (epoch % f_grad_norm == 0)\n", "\n", "    loss, aux, sep_grads = train_step(\n", "        model, optimizer, pde_collocs, ic_collocs, ic_data, λ_E, λ_I, l_E, l_I,\n", "        compute_grads_sep=do_anneal\n", "    )\n", "\n", "    loss_E, loss_I, l_E, l_I = aux\n", "\n", "    # Perform learning-rate annealing of the global loss weights\n", "    if do_anneal:\n", "\n", "        print(f\"Epoch No. {epoch}. Current loss: {loss:.2e}. Performing learning-rate annealing.\")\n", "\n", "        λ_E, λ_I = lr_anneal((sep_grads[0], sep_grads[1]), (λ_E, λ_I), grad_mixing)\n", "\n", "    # Perform RAD\n", "    if (epoch != 0) and (epoch % f_resample == 0):\n", "\n", "        print(f\"Epoch No. {epoch}. Current loss: {loss:.2e}. Performing RAD resampling.\")\n", "\n", "        sorted_indices, l_E_pool = get_rad_collocs(\n", "            model, collocs_pool, sorted_indices, l_E, l_E_pool\n", "        )\n", "\n", "        # Set new batch of collocs and l_E\n", "        pde_collocs = collocs_pool[sorted_indices]\n", "        l_E = l_E_pool[sorted_indices]\n", "\n", "    # Record the loss\n", "    train_losses = train_losses.at[epoch].set(loss)" ] }, { "cell_type": "code", "execution_count": null, "id": "d48ede2f-f25f-40fd-b472-5339e417c039", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "faab949e-dadb-4cec-bc13-63b10cb609c5", "metadata": {}, "source": [ "## Evaluation" ] }, { "cell_type": "markdown", "id": "0b482774-013c-4f1a-98dc-6e3a44171783", "metadata": {}, "source": [ "Plotting the training loss curve shows the effect of the adaptive training techniques: after an initial transient in the first few thousand epochs, the loss decreases steadily and reaches roughly $10^{-3}$ by the end of training (1.02e-03 at epoch 19000 in the log above)." 
] }, { "cell_type": "code", "execution_count": 11, "id": "8e82910c-d539-4ae1-adb4-c8d3caf597ce", "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAnEAAAGJCAYAAADlpGXRAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAnQZJREFUeJzsnXlc1VX+/58XkH2XTRREQRQFREQo2ydbnHbrl9nm0lRTOTXjNC3TjGXzbWqqcZrKpl3byxaradfKsVIRUEQRUBRwYRFk37n3fn5/0P3EhcvqXT6Xc56Phw/5fD7nfj7v1+cczn1zznm/j05RFAWJRCKRSCQSiVPh4mgDJBKJRCKRSCTDRzpxEolEIpFIJE6IdOIkEolEIpFInBDpxEkkEolEIpE4IdKJk0gkEolEInFCpBMnkUgkEolE4oRIJ04ikUgkEonECZFOnEQikUgkEokTIp04iUQikUgkEidEOnESicSMJUuWEBMTM6LPPvTQQ+h0OusaJJEMgqnd1dTUONoUicSuSCdOInESdDrdkP5t3rzZ0aY6hCVLluDr6+toM4aEoii88cYbnHnmmQQGBuLt7U1SUhIPP/wwLS0tjjavDyYnqb9/lZWVjjZRIhESN0cbIJFIhsYbb7xhdvz666+zcePGPucTEhJO6jkvvfQSRqNxRJ/9y1/+wn333XdSzx/tGAwGrr32WtavX88ZZ5zBQw89hLe3Nz/88AOrVq3i/fffZ9OmTYSHhzva1D785z//segoBwYG2t8YiUQinTiJxFm4/vrrzY63b9/Oxo0b+5zvTWtrK97e3kN+zpgxY0ZkH4CbmxtubrJbGYjHH3+c9evXc/fdd/PEE0+o52+55RauvvpqLr/8cpYsWcKXX35pV7uG0k6uuuoqQkJC7GSRRCIZDDmdKpGMIs4++2wSExPJycnhzDPPxNvbmz//+c8AfPLJJ1x00UVERkbi4eFBbGwsf/vb3zAYDGb36L0mrrS0FJ1Ox5NPPsmLL75IbGwsHh4ezJkzh6ysLLPPWloTp9PpWL58OR9//DGJiYl4eHgwY8YMvvrqqz72b968mbS0NDw9PYmNjeWFF16w+jq7999/n9mzZ+Pl5UVISAjXX389x44dMytTWVnJ0qVLmTBhAh4eHowbN47LLruM0tJStUx2djYXXHABISEheHl5MWnSJJYtWzbgs9va2njiiSeIj4/n0Ucf7XP9kksuYfHixXz11Vds374dgIsvvpjJkydbvN+pp55KWlqa2bk333xT1RccHMw111zDkSNHzMoM1E5Ohs2bN6PT6Xjvvff485//TEREBD4+Plx66aV9bICh1QVAYWEhV199NaGhoXh5eTF16lQeeOCBPuXq6+tZsmQJgYGBBAQEsHTpUlpbW83KbNy4kdNPP53AwEB8fX2ZOnWqVbRLJI5A/skskYwyTpw4wfz587nmmmu4/vrr1Wm5devW4evry4oVK/D19eW7775j5cqVNDY2mo0I9cfbb79NU1MTt956Kzqdjscff5wFCxZw6NChQUfvfvzxRz766CNuv/12/Pz8ePrpp7nyyis5fPgwY8eOBWDXrl1ceOGFjBs3jlWrVmEwGHj44YcJDQ09+ZfyM+vWrWPp0qXMmTOHRx99lKqqKv7973/z008/sWvXLnVa8MorryQ/P5/f/e53xMTEcPz4cTZu3Mjhw4fV4/PPP5/Q0FDuu+8+AgMDKS0t5aOPPhr0PdTV1XHXXXf1O2J54403snbtWj777DNOOeUUFi5cyI033khWVhZz5sxRy5WVlbF9+3azunvkkUf461//ytVXX81vfvMbqqureeaZZzjzzDPN9EH/7WQg
amtr+5xzc3PrM536yCOPoNPpuPfeezl+/DhPPfUU8+bNIzc3Fy8vL2DodZGXl8cZZ5zBmDFjuOWWW4iJieHgwYP897//5ZFHHjF77tVXX82kSZN49NFH2blzJy+//DJhYWH84x//ACA/P5+LL76Y5ORkHn74YTw8PCguLuann34aVLtEokkUiUTilNxxxx1K71/hs846SwGU559/vk/51tbWPuduvfVWxdvbW2lvb1fPLV68WJk4caJ6XFJSogDK2LFjldraWvX8J598ogDKf//7X/Xcgw8+2McmQHF3d1eKi4vVc7t371YA5ZlnnlHPXXLJJYq3t7dy7Ngx9dyBAwcUNze3Pve0xOLFixUfH59+r3d2diphYWFKYmKi0tbWpp7/7LPPFEBZuXKloiiKUldXpwDKE0880e+9NmzYoABKVlbWoHb15KmnnlIAZcOGDf2Wqa2tVQBlwYIFiqIoSkNDg+Lh4aH88Y9/NCv3+OOPKzqdTikrK1MURVFKS0sVV1dX5ZFHHjErt2fPHsXNzc3s/EDtxBKmerX0b+rUqWq577//XgGU8ePHK42Njer59evXK4Dy73//W1GUodeFoijKmWeeqfj5+ak6TRiNxj72LVu2zKzMFVdcoYwdO1Y9/te//qUASnV19ZB0SyRaR06nSiSjDA8PD5YuXdrnvGkEBKCpqYmamhrOOOMMWltbKSwsHPS+CxcuJCgoSD0+44wzADh06NCgn503bx6xsbHqcXJyMv7+/upnDQYDmzZt4vLLLycyMlItFxcXx/z58we9/1DIzs7m+PHj3H777Xh6eqrnL7roIqZNm8bnn38OdL8nd3d3Nm/eTF1dncV7mUaJPvvsM7q6uoZsQ1NTEwB+fn79ljFda2xsBMDf35/58+ezfv16FEVRy7333nuccsopREdHA/DRRx9hNBq5+uqrqampUf9FREQwZcoUvv/+e7Pn9NdOBuLDDz9k48aNZv/Wrl3bp9yNN95opvGqq65i3LhxfPHFF8DQ66K6upotW7awbNkyVacJS1Psv/3tb82OzzjjDE6cOKG+S1O9ffLJJyMO3pFItIR04iSSUcb48eNxd3fvcz4/P58rrriCgIAA/P39CQ0NVYMiGhoaBr1v7y9Rk0PXn6Mz0GdNnzd99vjx47S1tREXF9ennKVzI6GsrAyAqVOn9rk2bdo09bqHhwf/+Mc/+PLLLwkPD+fMM8/k8ccfN0ujcdZZZ3HllVeyatUqQkJCuOyyy1i7di0dHR0D2mBybEzOnCUsOXoLFy7kyJEjbNu2DYCDBw+Sk5PDwoUL1TIHDhxAURSmTJlCaGio2b+CggKOHz9u9pz+2slAnHnmmcybN8/s36mnntqn3JQpU8yOdTodcXFx6prCodaFyclPTEwckn2DtdGFCxdy2mmn8Zvf/Ibw8HCuueYa1q9fLx06idMinTiJZJTRc8TNRH19PWeddRa7d+/m4Ycf5r///S8bN25U1woN5UvM1dXV4vmeo0O2+Kwj+P3vf8/+/ft59NFH8fT05K9//SsJCQns2rUL6HZKPvjgA7Zt28by5cs5duwYy5YtY/bs2TQ3N/d7X1P6l7y8vH7LmK5Nnz5dPXfJJZfg7e3N+vXrAVi/fj0uLi78v//3/9QyRqMRnU7HV1991We0bOPGjbzwwgtmz7HUTpydwdqZl5cXW7ZsYdOmTdxwww3k5eWxcOFCzjvvvD4BPhKJMyCdOIlEADZv3syJEydYt24dd911FxdffDHz5s0zmx51JGFhYXh6elJcXNznmqVzI2HixIkAFBUV9blWVFSkXjcRGxvLH//4R7755hv27t1LZ2cn//znP83KnHLKKTzyyCNkZ2fz1ltvkZ+fz7vvvtuvDaaoyLfffrtfp+H1118HuqNSTfj4+HDxxRfz/vvvYzQaee+99zjjjDPMpp5jY2NRFIVJkyb1GS2bN28ep5xyyiBvyHocOHDA7FhRFIqLi9Wo56HWhSkqd+/evVaz
zcXFhXPPPZfVq1ezb98+HnnkEb777rs+080SiTMgnTiJRABMIxQ9R746Ozt57rnnHGWSGa6ursybN4+PP/6Y8vJy9XxxcbHV8qWlpaURFhbG888/bzbt+eWXX1JQUMBFF10EdOdLa29vN/tsbGwsfn5+6ufq6ur6jCKmpKQADDil6u3tzd13301RUZHFFBmff/4569at44ILLujjdC1cuJDy8nJefvlldu/ebTaVCrBgwQJcXV1ZtWpVH9sUReHEiRP92mVtXn/9dbMp4w8++ICKigp1feNQ6yI0NJQzzzyTV199lcOHD5s9YySjuJaia4dSbxKJVpEpRiQSAZg7dy5BQUEsXryYO++8E51OxxtvvKGp6cyHHnqIb775htNOO43bbrsNg8HAs88+S2JiIrm5uUO6R1dXF//3f//X53xwcDC33347//jHP1i6dClnnXUWixYtUtNaxMTE8Ic//AGA/fv3c+6553L11Vczffp03Nzc2LBhA1VVVVxzzTUAvPbaazz33HNcccUVxMbG0tTUxEsvvYS/vz+//vWvB7TxvvvuY9euXfzjH/9g27ZtXHnllXh5efHjjz/y5ptvkpCQwGuvvdbnc7/+9a/x8/Pj7rvvxtXVlSuvvNLsemxsLP/3f//H/fffT2lpKZdffjl+fn6UlJSwYcMGbrnlFu6+++4hvcf++OCDDyzu2HDeeeeZpSgJDg7m9NNPZ+nSpVRVVfHUU08RFxfHzTffDHQnlB5KXQA8/fTTnH766aSmpnLLLbcwadIkSktL+fzzz4fcLkw8/PDDbNmyhYsuuoiJEydy/PhxnnvuOSZMmMDpp58+spcikTgSh8TESiSSk6a/FCMzZsywWP6nn35STjnlFMXLy0uJjIxU7rnnHuXrr79WAOX7779Xy/WXYsRSyg1AefDBB9Xj/lKM3HHHHX0+O3HiRGXx4sVm57799ltl1qxZiru7uxIbG6u8/PLLyh//+EfF09Ozn7fwC4sXL+43DUZsbKxa7r333lNmzZqleHh4KMHBwcp1112nHD16VL1eU1Oj3HHHHcq0adMUHx8fJSAgQMnIyFDWr1+vltm5c6eyaNEiJTo6WvHw8FDCwsKUiy++WMnOzh7UTkVRFIPBoKxdu1Y57bTTFH9/f8XT01OZMWOGsmrVKqW5ubnfz1133XUKoMybN6/fMh9++KFy+umnKz4+PoqPj48ybdo05Y477lCKiorUMgO1E0sMlGKkZ/sxpRh55513lPvvv18JCwtTvLy8lIsuuqhPihBFGbwuTOzdu1e54oorlMDAQMXT01OZOnWq8te//rWPfb1Th6xdu1YBlJKSEkVRutvXZZddpkRGRiru7u5KZGSksmjRImX//v1DfhcSiZbQKYqG/hSXSCSSXlx++eXk5+f3WWcl0R6bN2/mnHPO4f333+eqq65ytDkSyahHromTSCSaoa2tzez4wIEDfPHFF5x99tmOMUgikUg0jFwTJ5FINMPkyZNZsmQJkydPpqysjP/85z+4u7tzzz33ONo0iUQi0RzSiZNIJJrhwgsv5J133qGyshIPDw9OPfVU/v73v/dJHiuRSCQSkGviJBKJRCKRSJwQuSZOIpFIJBKJxAmRTpxEIpFIJBKJEyLXxA2C0WikvLwcPz8/dDqdo82RSCQSiUQyilEUhaamJiIjI3FxGXisTTpxg1BeXk5UVJSjzZBIJBKJRCIQR44cYcKECQOWkU7cIPj5+QHdL9Pf398mz8jOziYtLc0m99YyUrc4iKgZxNQtomaQukXC1pobGxuJiopS/Y+BkE7cIJimUP39/W3mxPn4+Njs3lpG6hYHETWDmLpF1AxSt0jYS/NQlnDJFCOD0NjYSEBAAA0NDTarNEVRhFxvJ3WLg4iaQUzdImoGqVskbK15OH6HjE7VALm5uY42wSFI3eIgomYQU7eImkHqFgktaZZOnAbo7Ox0tAkOQeoWBxE1g5i6RdQMUrdIaEmzXBOnAQIDAx1t
gkOQusVBRM0gpm4RNYPtdSuKgl6vx2Aw2PQ5w8Xf35/29nZHm2FXTlazq6srbm5uVpmSlU6cBhgshHi0InWLg4iaQUzdImoG2+ru7OykoqKC1tZWmz1jpBiNRkpKShxthl2xhmZvb2/GjRuHu7v7Sd1HOnH9sGbNGtasWaP+1ZOdnY2Pjw+pqakUFBTQ1taGn58fkyZNIi8vD4CJEydiNBo5cuQIACkpKRQXF9Pc3IyPjw/x8fHs2rUL6P6Fd3V1paysjLq6Os466yxKS0tpbGzE09OTGTNmkJOTA0BkZCSenp4cOnQIgMTERI4ePUp9fT3u7u6kpKSwY8cOACIiIvD19aW4uBiAhIQEqqqqqK2txc3NjdmzZ7Njxw4URSE0NJSgoCD2798PwNSpU6mtraW6uhoXFxfmzJlDdnY2BoOBsWPHEhYWRkFBAQBTpkyhsbGRqqoqADIyMti5cyddXV0EBQURGRlJfn4+ALGxsbS2tlJRUQFAWloae/fupaKigpiYGKKjo9mzZw8AMTEx6PV6jh49CkBqaiqFhYW0trbi6+tLbGwsu3fvBiA6OhqAw4cPAzBz5kwOHjxIc3Mz3t7eTJs2jZ07d6rv283NjdLSUgCSkpI4fPgwDQ0NeHp6kpiYSHZ2NgDjxo3D29ubgwcPAjBjxgzKy8upq6tjzJgxpKamkpmZCUB4eDj+/v4cOHBAfd/Hjx/nxIkTuLq6kpaWRlZWFkajkdDQUIKDg9m+fTtBQUHEx8dTV1dHdXU1Op2O9PR0cnJy0Ov1BAcHEx4err7vuLg4mpubqaysBCA9PZ3c3Fw6OzsJDAxkwoQJ7N27F4DJkyfT3t5OeXk5ALNnzyY/P5/29nb8/f2JiYkxa7MGg0F937NmzWL//v20tLTg6+tLXFycuv4jKioKFxcXysrKAEhOTqakpISmpia8vLxISEhQ3/f48eNxd3dXOzq9Xk9ISAj19fV4eHiQnJxMVlaW2mZ9fHzU9z19+nQqKyupra3t877DwsIICAhQ3/e0adOoqamhpqZGbbOm9x0SEkJISAiFhYVqm21oaOD48eN92mxwcDARERHs27dPbbMtLS3q+54zZw55eXl0dHQQGBhIVFSU2mYnTZpEZ2cnx44dU9usqY9oamritNNOO+k+wvS+naGP2Lp1K0FBQVbpI9rb2wkICHCKPiIzMxNfX1+r9BFFRUUAxMfHU1tbS319Pd7e3kRGRmIwGFAUBTc3N9zc3NQRIQ8PD4xGI11dXUC3k9De3o7RaOxT1t3dHUVRLJZ1dXXF3d2dtrY2i2W9vLzo7OzEYDDg6upqNjI4ZswYdDqdOt3Ys6yLiwseHh7qfXuX9fT0RK/Xo9frcXFxwdPTU3VaByqr0+nw9vampaVFLevi4kJHR4f6XgwGg1nZ1tZW9R26urqale35Dn18fMzKmt6h0WjE09Oz33c42Pt2dXWlsrKSPXv24OnpyeTJk836iIaGBoaKjE4dBHtEp2ZmZpKRkWGTe2sZqVscRNQMYuoWUTPYTnd7ezslJSVMnDgRb29vq9//ZGlubsbX19fRZtgVa2hubW2lrKyMSZMm4enpaXZNRqc6GZMnT3a0CQ5B6hYHETWDmLpF1Ay21z3Y9kuOwsPDw9Em2B1raLZWfWqzVQiGaItCTUjd4iCiZhBTt4iaQVzdIk7maUmzdOIczLb8cj75XwHV9dpbsGprTGu2RENE3SJqBjF1i6gZxNWtpXQb9kJLmqUT52Du+88Wnv7iCM9tyHW0KRKJRCKRaJaYmBieeuopR5uhKaQT52BqGrojdb7MFCtEG7qjJkVERN0iagYxdYuoGcTV7ePj0+ecTqcb8N9DDz00omdlZWVxyy23nJS9Z599Nr///e9P6h6WNDsK6cQ5mNBALwAuzIhxrCEOwJReQDRE1C2iZhBTt4iaQVzdppQhPamoqFD/PfXUU/j7+5udu/vuu9WypiTGQyE0NFQTEbqWNDsK6cQ5mOCAbifu
rJlRDrbE/oi6EFhE3SJqBjF1i6gZ7KtbURRa27vs/s/Sgn6j0djnXEREhPovICAAnU6nHhcWFuLn58eXX37J7Nmz8fDw4Mcff+TgwYNcdtllhIeH4+vry5w5c9i0aZPZfXtPp+p0Ol5++WWuuOIKvL29mTJlCp9++ulJvdsPP/yQGTNm4OHhQUxMDP/85z/Nrj/33HMkJyfj6elJeHg4V111lXrtgw8+ICkpCS8vL8aOHcu8efPU/HW2Qib7dTAnv+mG82KrvHtaR0TdImoGMXWLqBnsq7utQ0/Kstft9jwTua/eiLfnGLNzrq6uI7rXfffdx5NPPsnkyZMJCgriyJEj/PrXv+aRRx7Bw8OD119/nUsuuYSioiI1abMlVq1axeOPP84TTzzBM888w3XXXUdZWRnBwcHDtiknJ4err76ahx56iIULF7J161Zuv/12xo4dy5IlS8jOzubOO+/klVde4eyzz6a2tpYffvgB6B59XLRoEY8//jhXXHEFTU1N/PDDDzaPZJVOnEbQUMSy3YiJiXG0CQ5BRN0iagYxdYuoGcTVPdKcaQ8//DDnnXeeehwcHMzMmTPV47/97W9s2LCBTz/9lOXLl/d7nyVLlrBo0SIA/v73v/P000+zY8cOLrzwwmHbtHr1as4991z++te/At07Zuzbt48nnniCJUuWcPjwYXx8fLj88ssJCAhg4sSJzJo1C+h24vR6PQsWLGDixIlA964ftkY6cQ7GtAGugnheXF5enpCZ3UXULaJmEFO3iJrBvrq9PNzIffXGYX/u/e+LePWLvSz7dSL/75ypI3pub0zbnQ2XtLQ0s+Pm5mYeeughPv/8c9UhamtrU7dM64/k5GT1Zx8fH/z9/dWt9YZLQUEBl112mdm50047jaeeegqDwcB5553HxIkTiY2NZf78+Vx44YXqVO7MmTM599xzSUpK4oILLuD888/nqquuIigoaES2DBXpxDmYn304IUfiJBKJRDJ8dDpdn2nNobB4fiKL5yfawKLh0zvC8+6772bjxo08+eSTxMXF4eXlxVVXXTVoTrYxY8zfg06ns7hOzxr4+fmxc+dOvvzyS3744QdWrlzJQw89RFZWFoGBgWzcuJGtW7fyzTff8Mwzz/DAAw+QmZnJpEmTbGIPyMAGhyPymjjTkLNoiKhbRM0gpm4RNYO4uq217dZPP/3EkiVLuOKKK0hKSiIiIoLS0lKr3HuoJCQk8NNPP/WxKz4+Xl375+bmxoUXXsjjjz9OXl4epaWlfPfdd0C3A3naaaexatUqdu3ahbu7Oxs2bLCpzXIkTiNoaRsPe2EwGBxtgkMQUbeImkFM3SJqBnF1W+u7a8qUKXz00Udccskl6HQ6/vrXv9psRK26uprc3Fyzc+PGjeOPf/wjc+bM4W9/+xsLFy5k27ZtPPvsszz33HMAfPbZZxw6dIhTTz2VsLAwvvjiC4xGI1OnTiUzM5Nvv/2W888/n7CwMDIzM6muriYhIcEmGkzIkTgH88uaOPE4evSoo01wCCLqFlEziKlbRM0grm5rbUG1evVqgoKCmDt3LpdccgkXXHABqampVrl3b95++21mzZpl9u+ll14iNTWV9evX8+6775KYmMjKlSt5+OGHWbJkCQCBgYF89NFHXHDBBSQkJPD888/zzjvvMGPGDPz9/dmyZQu//vWviY+P5y9/+Qv//Oc/mT9/vk00mJAjcQ5GromTSCQSyWhlyZIlqhME3TsmWBq9i4mJUaclTdxxxx1mx72nVy3dp76+fkB7Nm/ePOD1K6+8kiuvvNLitdNPP53NmzfT3NzcJ5gjISGBr776asB72wI5EudwxF0VZwrNFg0RdYuoGcTULaJmEFe3FnZQsDda0iydOM0g3lDc/v37HW2CQxBRt4iaQUzdImoGcXV3dHQ42gS7oyXN0olzMCJPp9p6OxKtIqJuETWDmLpF1Azi6hYxoENLmqUT52BEduJGkiByNCCibhE1g5i6RdQM4uoe
6bZbzoyWNAvhxH322WdMnTqVKVOm8PLLLzvaHDN0Aq+Ji4uLc7QJDkFE3SJqBjF1i6gZbK9bq2morJUnzpmwhmZr1eeod+L0ej0rVqzgu+++Y9euXTzxxBOcOHHC0Wb1QcRtt3rn6REFEXWLqBnE1C2iZrCdbtOOBK2trTa5/8miVbtsiTU0m+7Re8eJ4TLqU4zs2LGDGTNmMH78eADmz5/PN998o26Y62hEnk6VSCQSycC4uroSGBio7gfq7e2t5hfVAh0dHbi5jXpXwoyT0awoCq2trRw/fpzAwMCTnprV/JvfsmULTzzxBDk5OVRUVLBhwwYuv/xyszJr1qzhiSeeoLKykpkzZ/LMM8+Qnp4OQHl5uerAAYwfP55jx47ZU8KAHK/r9sZ/2nOMi06d7GBr7EtUVJSjTXAIIuoWUTOIqVtEzWBb3REREQAj3tjdlhgMBk2tEbMH1tAcGBio1uvJoHknrqWlhZkzZ7Js2TIWLFjQ5/p7773HihUreP7558nIyOCpp57iggsuoKioiLCwsGE/r6Ojwyx8uLGx8aTsH4zyE80AfJtTBpxh02dpDReXUT+bbxERdYuoGcTULaJmsK1unU7HuHHjCAsLo6ury2bPGQnV1dWEhoY62gy7crKax4wZYzXHV/NO3Pz58wfctmL16tXcfPPNLF26FIDnn3+ezz//nFdffZX77ruPyMhIs5G3Y8eOqaN0lnj00UdZtWpVn/PZ2dn4+PiQmppKQUEBbW1t+Pn5MWnSJPLy8oDuDZCNRiNHjhwBICUlheLiYpqbm/Hx8SE+Pp5du3YBMGHCBFxdXQn2caOyvpOzZo6noKCAxsZGPD09mTFjBjk5OQBERkbi6enJoUOHAEhMTOTo0aPU19fj7u5OSkoKO3bsALr/YvP19aW4uBjoziJdVVVFbW0tbm5uzJ49mx07dqAoCqGhoQQFBan5jaZOnUptbS3V1dW4uLgwZ84csrOzMRgMjB07lrCwMAoKCoDufe4aGxupqqoCICMjg507d9LV1UVQUBCRkZHk5+cDEBsbS2trKxUVFQCkpaWxd+9eKioqiImJITo6mj179gDdWbv1er26hU1qaiqFhYW0trbi6+tLbGwsu3fvBiA6OhqAw4cPAzBz5kwOHjxIc3Mz3t7eTJs2jZ07d6rv283NTc34nZSUxOHDh2loaMDT05PExESys7OB7j30vL29OXjwIAAzZsygvLycuro6xowZQ2pqKpmZmQCEh4fj7+/PgQMH1Pd9/PhxTpw4gaurK2lpaWRlZWE0GgkNDSU4OJjc3FyCgoKIj4+nrq6O6upqdDod6enp5OTkoNfrCQ4OJjw8XH3fcXFxNDc3U1lZCUB6ejq5ubl0dnYSGBjIhAkT2Lt3LwCTJ0+mvb2d8vJyAGbPnk1+fj7t7e34+/sTExNj1mYNBoP6vmfNmsX+/ftpaWnB19eXuLg4da1PVFQULi4ulJWVAZCcnExJSQlNTU14eXmRkJCgvu/x48fj7u5OSUkJ0L02taGhgfr6ejw8PEhOTiYrK0ttsz4+Pur7nj59OpWVldTW1vZ532FhYQQEBKjve9q0adTU1FBTU6O2WdP7DgkJISQkhMLCQrXNNjQ0qKMZPdtscHAwERER7Nu3T22zLS0t6vueM2cOeXl5dHR0EBgYSFRUlNpmJ02aRGdnp9rP9OwjmpqaCAgIOKk+ouf7Li0t1XwfYWrf1ugj2tvbCQgIcIo+Ij8/n7KyMqv0EUVFRQAW+4jc3FxN9RGNjY3qfU6mj0hKSuLIkSNO0UfU1dWRlpZ2Un1Efn5+v35EQ0MDQ0ZxIgBlw4YN6nFHR4fi6upqdk5RFOXGG29ULr30UkVRFKWrq0uJi4tTjh49qjQ1NSnx8fFKTU1Nv89ob29XGhoa1H9HjhxRAKWhocEWkpQb/u9zZcqil5VPfyy2yf21zPbt2x1tgkMQUbeImhVFTN0ialYUqVskbK25oaFhyH6H5kfiBqKmpgaD
wUB4eLjZ+fDwcNWzdnNz45///CfnnHMORqORe+65h7Fjx/Z7Tw8PD7uGTIucYiQ5OdnRJjgEEXWLqBnE1C2iZpC6RUJLmoVYvHDppZeyf/9+iouLueWWWxxtjhm/RKeKF55qGkYXDRF1i6gZxNQtomaQukVCS5qdeiQuJCQEV1dXdc2FiaqqqpOO+lizZg1r1qxRt9ew1Zo409x3R2enkGviXFxcnGK9izXXxB0+fJimpibh1sQVFRU5xXoXU5u11pq4k+0jnG1NnKl9i7YmrqKigqamJpuvidNaH9HY2KhqFWlNXFhYmFX6iJNdE6dTnGgISKfT9UkxkpGRQXp6Os888wwARqOR6Oholi9fzn333XfSz2xsbCQgIICGhgb8/f1P+n69Wfrol/y0p5zHbzuTy8+YYvX7a5m8vDxNDUvbCxF1i6gZxNQtomaQukXC1pqH43dofiSuublZ/YsRuocxc3NzCQ4OJjo6mhUrVrB48WLS0tJIT0/nqaeeoqWlRY1W1Toir4lLSEhwtAkOQUTdImoGMXWLqBmkbpHQkmbNr4nLzs5m1qxZzJo1C4AVK1Ywa9YsVq5cCcDChQt58sknWblyJSkpKeTm5vLVV1/1CXbQKiLv2GAaVhcNEXWLqBnE1C2iZpC6RUJLmjU/Enf22WcPuuh/+fLlLF++3KrPtdeauPqf5747BVwTV1dXR2FhoVOsd7Hmmri6ujoyMzOdYr2LXBN38mvi2trahFoTZ2rfoq2Ja25uJjMzU7g1cUajUcg1cTU1NXJNnDNg6zVxN/3ja37YfZTHbj2DBWfFW/3+Wubo0aNMmDDB0WbYHRF1i6gZxNQtomaQukXC1pqH43dofjp1tFNR073t1vZ9FQ62xP64u7s72gSHIKJuETWDmLpF1AxSt0hoSbN04hzMkeNNAGzedcTBltgfLeXasSci6hZRM4ipW0TNIHWLhJY0SyfOwUSH+wFwVkqUgy2RSCQSiUTiTGg+sMFR2Cuwwcet+/6z4sYKF9ig1+uFDGzQ6/XCBTbExcUJGdjg6ekpXGCDqX2LFtjg6+srZGDD1KlThQts0Ov1MrDBWbB1YMOtT3zD97uO8MjNp/P/zplq9ftrmaKiIqZOFUsziKlbRM0gpm4RNYPULRK21iwDG5wI3c+J4kT0pOvr6x1tgkMQUbeImkFM3SJqBqlbJLSkWTpxDuZYdXdgQ+a+cgdbYn88PDwcbYJDEFG3iJpBTN0iagapWyS0pFlOpw6CradTk5eso73TQICvB1kvXm/1+2sZo9GIi4t4f0eIqFtEzSCmbhE1g9QtErbWLKdTnYiJEd0VdGbyeAdbYn9Mi1ZFQ0TdImoGMXWLqBmkbpHQkmYZndoP9opO9R2jB2BmrHjRqXLbLe1Hnsltt+S2W3LbLbntltx2S2675bTYejr19tWb2JRdxqplc1k0L8Hq99cyZWVlTJw40dFm2B0RdYuoGcTULaJmkLpFwtaa5XSqE/FzcKqQ+Pj4ONoEhyCibhE1g5i6RdQMUrdIaEmzdOIczNGft93aUVDpYEvsj2loXDRE1C2iZhBTt4iaQeoWCS1plk6cgymp6J77/jHvqIMtkUgkEolE4kxIJ87BTI4MAOC0JPGiU6dPn+5oExyCiLpF1Axi6hZRM0jdIqElzdKJczDRYd2LFtOmRTjYEvtjiuoRDRF1i6gZxNQtomaQukVCS5plipF+sFeKkdq6WgA6OzuFSzFSUVGBwWBwivQB1kwxcvDgQWpra50ifYA1U4wYjUanSB9garPWSjEyYcIEoVKMmNq3aClGDh8+bLHNjvYUI42NjdTWdn+PiZRiJCgoSKYYcQZsnWLkigc+Jr/kBPNPmcS/7/yV1e+vZXbu3ElqaqqjzbA7IuoWUTOIqVtEzSB1i4StNQ/H75BO3CDYbdstH3eyXrrB6veXSCQSiUTiPMg8cU5E7PhAAOYmihfYYBr+Fg0RdYuoGcTULaJmkLpFQkuapRPnYGIi
uqNTU+PDHGyJRCKRSCQSZ0I6cQ7GtGODiHPaYWFiOq4i6hZRM4ipW0TNIHWLhJY0SydO4jACAgIcbYJDEFG3iJpBTN0iagapWyS0pFk6cQ6mrLIRgJyiKgdbYn9MIeCiIaJuETWDmLpF1AxSt0hoSbPME9cP9soTt//wCQC27T0mXJ64uro6CgsLnSIHlDXzxNXV1ZGZmekUOaCsmSeuqKjIKXJAmdqstfLEtbW1CZUnztS+RcsT19zcTGZmpnB54oxGo6pVpDxxNTU1Mk+cM2DrFCNX/fUT8g7WcP6cGJ79w7lWv7+WaWho0NSwtL0QUbeImkFM3SJqBqlbJGytWaYYcSJixnU3hFkCRqfW1NQ42gSHIKJuETWDmLpF1AxSt0hoSbN04hyMDnHDU7X0i2BPRNQtomYQU7eImkHqFgktaZZOnINRU4wIOKvt4iJm8xNRt4iaQUzdImoGqVsktKRZrokbBFuvibt65afkFldzXtpE1qyYZ/X7SyQSiUQicR7kmjgnouhIHQCZ+yocbIn9MUUeiYaIukXUDGLqFlEzSN0ioSXN0olzMFOjggBInx7hYEvsj9FodLQJDkFE3SJqBjF1i6gZpG6R0JJm6cQ5mNjxgQDMjBMvOjUkJMTRJjgEEXWLqBnE1C2iZpC6RUJLmqUT52B0P0c2iLgyUUu/CPZERN0iagYxdYuoGaRukdCSZunESRyGKUu2aIioW0TNIKZuETWD1C0SWtIst93qB3ttu1VdXQ1AV1eX3HYL7W6pI7fdkttuyW235LZbctstue2W3HbLybB1ipEHXvqB97/fzx+uns1tl6dY/f5apra2luDgYEebYXdE1C2iZhBTt4iaQeoWCVtrlilGnIiDx7o97t3F1Q62xP4M56+N0YSIukXUDGLqFlEzSN0ioSXN0olzMIVlJwDIKqx0sCX2xzRkLRoi6hZRM4ipW0TNIHWLhJY0SyfOwSTEjAUgbWq4gy2RSCQSiUTiTMg1cYNg6zVxK1/5iXe/LeTOq1JZvmCW1e8vkUgkEonEeZBr4pwI3c//i+hLm6KURENE3SJqBjF1i6gZpG6R0JJm6cQ5mJ9z/QqZ7Lerq8vRJjgEEXWLqBnE1C2iZpC6RUJLmqUT52hMOzY42AxHIFpYugkRdYuoGcTULaJmkLpFQkuapRPnYEzTqSIOxUVERDjaBIcgom4RNYOYukXUDFK3SGhJs3TiHIzIe6easuOLhoi6RdQMYuoWUTNI3SKhJc3SiXMwB47WArC3RLxkvxKJRCKRSEaOdOIczN5D3cl+d+7XTvJAexEbG+toExyCiLpF1Axi6hZRM0jdIqElzdKJczBJk0MAmDUlzMGW2J+WlhZHm+AQRNQtomYQU7eImkHqFgktaZZOnIOZNrE7ysW0c4NIVFaKt9UYiKlbRM0gpm4RNYPULRJa0uzmaAO0ypo1a1izZg0GgwGA7OxsfHx8SE1NpaCggLa2Nvz8/Jg0aRJ5eXkATJw4EaPRyJEjRwBISUmhuLiY5uZmfHx8iI+PZ9euXQBMmDABV1dXtTF0dekpKCigsbERT09PZsyYQU5ODgCRkZF4enpy6NAhABITEzl69Cj19fW4u7uTkpLCjh07gO6oGV9fX4qLiwFISEigqqqK2tpa3NzcmD17Njt27EBRFEJDQwkKCmL//v0ATJ06ldraWqqrq3FxcWHOnDlkZ2djMBgYO3YsYWFhFBQUADBlyhQaGxupqqoCICMjg507d9LV1UVQUBCRkZHk5+cD3UPPra2tVFRUAJCWlsbevXupq6ujsLCQ6Oho9uzZA0BMTAx6vZ6jR48CkJqaSmFhIa2trfj6+hIbG8vu3bsBiI6OBuDw4cMAzJw5k4MHD9Lc3Iy3tzfTpk1TkzJOmDABNzc3SktLAUhKSuLw4cM0NDTg6elJYmIi2dnZAIwbNw5vb28OHjwIwIwZMygvL6euro4xY8aQmppKZmYmAOHh4fj7+3PgwAH1fR8/fpwT
P368snDhQqW4uFi93tbWptx+++1KUFCQ4u3trVxxxRVKRUWF2T2cTbOJr7/+WgGUoqIis/Ojpa6///57i2168eLFiqJ0pxn561//qoSHhyseHh7Kueee2+ddnDhxQlm0aJHi6+ur+Pv7K0uXLlWamprMyuzevVs5/fTTFQ8PD2X8+PHKY4891seW9evXK/Hx8Yq7u7syY8YM5fPPP3eI7pKSkn5/1005AnNycpSMjAwlICBA8fT0VBISEpS///3vZs6Os+lubW1Vzj//fCU0NFQZM2aMMnHiROXmm2/u80XrbPU9WBtXFEV54YUXFC8vL6W+vr7P5521rgf7vlIU+/bd1vze1/0sUCKRSCQSiUTiRMg1cRKJRCKRSCROiHTiJBKJRCKRSJwQ6cRJJBKJRCKROCHSiZNIJBKJRCJxQqQTJ5FIJBKJROKESCdOIpFIJBKJxAmRTpxEIpFIJBKJEyKdOIlEIpFIJBInRDpxEolE4kB0Oh0ff/yxo82QSCROiHTiJBKJsCxZsgSdTtfn34UXXuho0yQSiWRQ3BxtgEQikTiSCy+8kLVr15qd8/DwcJA1EolEMnTkSJxEIhEaDw8PIiIizP4FBQUB3VOd//nPf5g/fz5eXl5MnjyZDz74wOzze/bs4Vf/v327B0mujcMAftknKgWWFjY1JGJCDRVhH0OPUBgEhhGBhLSIZtLSEn3Z0BbVJgg1FQkOgVQW1ShEQWRB1lZLSEUNKeTi/Q4vCBJPxPvE+3Ty+oFw7vt/POd/O10c7/PrF+RyOSorK+F0OpFMJnPOWVtbg9FoRGlpKbRaLcbGxnLqT09P6O/vh0KhgE6nQzgcztZeXl5gt9uh0Wggl8uh0+nehU4iyk8McUREH5iZmYHNZkMsFoPdbsfQ0BDi8TgAIJVKoaenByqVCqenpwiFQjg8PMwJaX6/Hx6PB06nE5eXlwiHw6irq8u5x/z8PAYHB3FxcYHe3l7Y7XY8Pz9n7391dYVIJIJ4PA6/3w+1Wv3//QBE9H0JIqI85XA4RGFhoVAqlTmfhYUFIYQQAITL5cr5Tmtrq3C73UIIIQKBgFCpVCKZTGbrOzs7oqCgQCQSCSGEEDU1NWJqauq3PQAQ09PT2XEymRQARCQSEUII0dfXJ0ZGRr5mwUT0o3BPHBHlta6uLvj9/py5ioqK7LHJZMqpmUwmnJ+fAwDi8TgaGxuhVCqz9fb2dmQyGdzc3EAmk+H+/h5ms/nDHhoaGrLHSqUS5eXleHh4AAC43W7YbDacnZ2hu7sbVqsVbW1t/2mtRPSzMMQRUV5TKpXv/t78KnK5/FPnFRcX54xlMhkymQwAwGKx4O7uDru7uzg4OIDZbIbH48Hi4uKX90tE0sI9cUREHzg+Pn43NhgMAACDwYBYLIZUKpWtR6NRFBQUQK/Xo6ysDLW1tTg6OvqjHjQaDRwOB9bX17GysoJAIPBH1yOin4FP4ogor6XTaSQSiZy5oqKi7MsDoVAIzc3N6OjowMbGBk5OTrC6ugoAsNvtmJubg8PhgM/nw+PjI7xeL4aHh1FdXQ0A8Pl8cLlcqKqqgsViwevrK6LRKLxe76f6m52dRVNTE4xGI9LpNLa3t7MhkojyG0McEeW1vb09aLXanDm9Xo/r62sA/745GgwGMTo6Cq1Wi83NTdTX1wMAFAoF9vf3MT4+jpaWFigUCthsNiwtLWWv5XA48Pb2huXlZUxMTECtVmNgYODT/ZWUlGBychK3t7eQy+Xo7OxEMBj8gpUTkdTJhBDibzdBRPQdyWQybG1twWq1/u1WiIje4Z44IiIiIgliiCMiIiKSIO6JIyL6De42IaLvjE/iiIiIiCSIIY6IiIhIghjiiIiIiCSIIY6IiIhIghjiiIiIiCSIIY6IiIhIghjiiIiIiCSIIY6IiIhIgv4B074HgukS0T8AAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(7, 4))\n", "\n", "plt.plot(np.array(train_losses), label='Train Loss', marker='o', color='#25599c', markersize=1)\n", "\n", "plt.xlabel('Epochs')\n", "plt.ylabel('Loss')\n", "plt.title('Training Loss Over Epochs')\n", "plt.yscale('log')\n", "\n", "plt.legend()\n", "plt.grid(True, which='both', linestyle='--', linewidth=0.5)\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "4860724c-a72f-41cf-8882-902e5c05f601", "metadata": {}, "source": [ "Additionally, the trained model closely approximates the actual solution of the Allen-Cahn equation, a level of accuracy we could not reach without adaptive training techniques." ] }, { "cell_type": "code", "execution_count": 12, "id": "7798eb68-d6b5-47ec-b845-cf3d60dccf23", "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAApsAAAGGCAYAAAA0Mkq8AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAmilJREFUeJztnXmcFNW593/VPRuLM6DADOgoAomIiijIBIxxGx2UF4PRKGoUSAQ3NIpeFRcWUXGLQQ2Rq1HBGwheE/FqMCiiJDeRC4oSNyRRUVCZUUQY1ple6v1jnHb61NPTT5061V3d/XzzqU/s4pxTp0/XdP362Y5l27YNQRAEQRAEQfCBULYnIAiCIAiCIOQvIjYFQRAEQRAE3xCxKQiCIAiCIPiGiE1BEARBEATBN0RsCoIgCIIgCL4hYlMQBEEQBEHwDRGbgiAIgiAIgm+I2BQEQRAEQRB8Q8SmIAiCIAiC4BsiNgXBICeccAJOOOEEo2N+8sknsCwL8+bNMzquaZYuXYpBgwahrKwMlmVh27ZtRsa1LAvTp09PvJ43bx4sy8Inn3xiZPwgMG7cOHTu3Dnb0wg8K1asgGVZWLFiRbanIgiCC0RsCgXNO++8g7PPPhsHHXQQysrKsP/+++OUU07BQw89lPG5LFy4ELNnz874dU3w9ddf45xzzkGHDh0wZ84c/Nd//Rc6deqUtt9vf/tbWJaFmpqaDMzSHI2NjZgxYwaOPPJIdO7cGR06dMDhhx+OG264AV988UW2p5eSVrGW6li0aFG2pwig5b4I+o8rQRD4FGV7AoKQLV577TWceOKJOPDAAzFhwgRUVVVh06ZN+L//+z888MADuPLKKzM6n4ULF+Ldd9/F1VdfnXT+oIMOwp49e1BcXJzR+bjh9ddfx44dOzBz5kzU1tay+y1YsAC9e/fG6tWr8eGHH6Jfv34+ztIMH3/8MWpra7Fx40b89Kc/xcSJE1FSUoK3334bjz32GBYvXox//etf2Z5mu1x11VU45phjHOeHDRuWhdk4+e1vf4tu3bph3LhxSed/9KMfYc+ePSgpKcnOxARB0ELEplCw3HHHHaioqMDrr7+OLl26JP3bl19+mZ1JEViWhbKysmxPo11a10tdx/bYsGE
DXnvtNTzzzDO45JJLsGDBAkybNs2nGZohGo3iJz/5CRoaGrBixQr88Ic/TPr3O+64A3fffXeWZsfnuOOOw9lnn53tabgmFAoF/m9BEAQn4kYXCpaPPvoIhx12GCmQevTokfQ6Go1i5syZ6Nu3L0pLS9G7d2/cdNNNaGpqavcaqeIL1dizE044AUuWLMGnn36acGn27t0bQOqYzVdeeQXHHXccOnXqhC5duuDHP/4x1q1bl9Rm+vTpsCwLH374IcaNG4cuXbqgoqIC48ePx+7du9OuEQA8/fTTGDx4MDp06IBu3brhZz/7GT7//PPEv59wwgkYO3YsAOCYY46BZVkOixTFggUL0LVrV4wcORJnn302FixYwJpPKv7yl78k1mOfffbByJEj8d577yW1aY2N/PzzzzF69Gh07twZ3bt3x3XXXYdYLJb2Gn/605/wz3/+EzfffLNDaAJAeXk57rjjjsTr//3f/8VPf/pTHHjggSgtLUV1dTWuueYa7Nmzhxw/3bxa74X77rsPjzzySOJ+POaYY/D6669zl4pFU1MTrrnmGnTv3h377LMPzjjjDHz22WeOGNpx48Yl7tW2tN57bXniiSdw0kknoUePHigtLcWAAQPw8MMPJ7Xp3bs33nvvPfz1r39N/C20xkGnitlMd4+2ztPLZy8Igj5i2RQKloMOOggrV67Eu+++i8MPP7zdthdffDHmz5+Ps88+G9deey1WrVqFWbNmYd26dVi8eLHnudx8883Yvn07PvvsM/z6178GgHYTRl5++WWcdtpp6NOnD6ZPn449e/bgoYcewrHHHos333zT8fA/55xzcPDBB2PWrFl488038bvf/Q49evRIa4WbN28exo8fj2OOOQazZs1CQ0MDHnjgAfzjH//AW2+9hS5duuDmm2/GIYccgkceeQS33XYbDj74YPTt2zfte16wYAF+8pOfoKSkBOeddx4efvhhvP7666R7Nx3/9V//hbFjx6Kurg533303du/ejYcffhg//OEP8dZbbyWtRywWQ11dHWpqanDffffh5Zdfxq9+9Sv07dsXl112WbvXee655wAAF154IWteTz/9NHbv3o3LLrsM++23H1avXo2HHnoIn332GZ5++umktm7mtXDhQuzYsQOXXHIJLMvCPffcg5/85Cf4+OOPWeEWO3bswJYtWxzn99tvv4RAvPjii/H73/8e559/PoYPH45XXnkFI0eOZL3vVDz88MM47LDDcMYZZ6CoqAjPP/88Lr/8csTjcVxxxRUAgNmzZ+PKK69E586dcfPNNwMAKisrU47JuUdb8fLZC4LgAVsQCpSXXnrJDofDdjgctocNG2Zff/319osvvmg3NzcntVu7dq0NwL744ouTzl933XU2APuVV15JnDv++OPt448/PvH6iSeesAHYGzZsSOr76quv2gDsV199NXFu5MiR9kEHHeSY54YNG2wA9hNPPJE4N2jQILtHjx72119/nTj3z3/+0w6FQvZFF12UODdt2jQbgP3zn/88acwzzzzT3m+//VItjW3btt3c3Gz36NHDPvzww+09e/Ykzv/5z3+2AdhTp051vM/XX3+93TFbeeONN2wA9rJly2zbtu14PG4fcMAB9i9/+UtHWwD2tGnTHNdqXdMdO3bYXbp0sSdMmJDUr76+3q6oqEg6P3bsWBuAfdtttyW1Peqoo+zBgwennfdRRx1lV1RUsN6jbdv27t27HedmzZplW5Zlf/rpp67n1Xov7LfffvbWrVsT5//nf/7HBmA///zz7c6n9b5LdWzevNm27e/u+csvvzyp//nnn+/4PMaOHUvet633Xrr1qKurs/v06ZN07rDDDkv6O1Ln3/p34+Ye9frZC4Kgj7jRhYLllFNOwcqVK3HGGWfgn//8J+655x7U1dVh//33T1iwAOCFF14AAEyePDmp/7XXXgsAWLJkSeYmDWDz5s1Yu3Ytxo0bh3333TdxfuDAgTjllFMS823LpZdemvT6uOO
Ow9dff43GxsaU13njjTfw5Zdf4vLLL0+Kkxs5ciT69+/v6X0vWLAAlZWVOPHEEwG0xKWee+65WLRokWuX5rJly7Bt2zacd9552LJlS+IIh8OoqanBq6++6uhDrcfHH3+c9lqNjY3YZ5992HPr0KFD4r937dqFLVu2YPjw4bBtG2+99Zb2vM4991x07do1qR0A1nsAgKlTp2LZsmWOo/V+ar2HrrrqqqR+avKaW9qux/bt27FlyxYcf/zx+Pjjj7F9+3bX4+nco7qfvSAI+ogbXShojjnmGDzzzDNobm7GP//5TyxevBi//vWvcfbZZ2Pt2rUYMGAAPv30U4RCIUemdFVVFbp06YJPP/00o3Nuvd4hhxzi+LdDDz0UL774Inbt2pVUeujAAw9MatcqVL755huUl5e7vk7//v3x97//XWv+sVgMixYtwoknnogNGzYkztfU1OBXv/oVli9fjlNPPZU93r///W8AwEknnUT+u/r+ysrK0L1796RzXbt2xTfffJN4/dVXXyWJ3s6dO6Nz584oLy93JUw2btyIqVOn4rnnnksaH4BDXHHm1Up7nyeHI444ot2qAa33vBoOQd0LbvjHP/6BadOmYeXKlY6Y4e3bt6OiosLVeG7vUTdrLAiCOURsCgKAkpISHHPMMTjmmGPw/e9/H+PHj8fTTz+dlB2tJjtwSNUn0wkJ4XCYPG/bdkbnAbQkNm3evBmLFi0i6zouWLDAldiMx+MAWuI2q6qqHP9eVJT8NZdqLdpyzDHHJP2ImDZtGqZPn47+/fvjrbfewqZNm1BdXd3uGLFYDKeccgq2bt2KG264Af3790enTp3w+eefY9y4cYl5u5lXurbZ+Dy59/hHH32Ek08+Gf3798f999+P6upqlJSU4IUXXsCvf/1rx3r4gZs1FgTBHCI2BUFhyJAhAFrc1UBLIlE8Hse///1vHHrooYl2DQ0N2LZtGw466KCUY7VanNTddChrKFfMtl5v/fr1jn/74IMP0K1bN1ZBdTfXUa2G69evb/d9t8eCBQvQo0cPzJkzx/FvzzzzDBYvXoy5c+cmuVzbo9X61qNHD1c1PtPNsW3GeJ8+fQAAo0aNwh/+8Af8/ve/x5QpU9od45133sG//vUvzJ8/HxdddFHi/LJly4zM0U9a7/mPPvooyWpI3XNdu3Yld4tS7/Hnn38eTU1NeO6555Iss1SYg87fgsl7VBAEs0jMplCwvPrqq6QlqDVerfUhe/rppwOAY3ef+++/HwDazdBtFUJ/+9vfEudisRgeeeQRR9tOnTqx4tZ69uyJQYMGYf78+UkP+XfffRcvvfRSYr5eGTJkCHr06IG5c+cmlXj6y1/+gnXr1mllJu/ZswfPPPMM/t//+384++yzHcekSZOwY8eOpJjZdNTV1aG8vBx33nknIpGI49+/+uor1/M89thjUVtbmzhaxebZZ5+NI444AnfccQdWrlzp6Ldjx45EBnWrFa3tPWbbNh544AHX88k0p512GgDgwQcfTDpP7XDVt29fbN++HW+//Xbi3ObNmx1VGqj12L59O5544gnHmJ06dWJtd+rHPSoIgnnEsikULFdeeSV2796NM888E/3790dzczNee+01PPXUU+jduzfGjx8PADjyyCMxduxYPPLII9i2bRuOP/54rF69GvPnz8fo0aMTSS4Uhx12GH7wgx9gypQp2Lp1K/bdd18sWrQI0WjU0Xbw4MF46qmnMHnyZBxzzDHo3LkzRo0aRY5777334rTTTsOwYcPwi1/8IlH6qKKiIqkGoheKi4tx9913Y/z48Tj++ONx3nnnJcrK9O7dG9dcc43rMZ977jns2LEDZ5xxBvnvP/jBD9C9e3csWLAA5557LmvM8vJyPPzww7jwwgtx9NFHY8yYMejevTs2btyIJUuW4Nhjj8VvfvMb13OlKC4uxjPPPIPa2lr86Ec/wjnnnINjjz0WxcXFeO+997Bw4UJ07doVd9xxB/r374++ffviuuu
uw+eff47y8nL86U9/CkR84P/+7/9i7969jvMDBw7EwIEDMWjQIJx33nn47W9/i+3bt2P48OFYvnw5PvzwQ0efMWPG4IYbbsCZZ56Jq666KlF26vvf/z7efPPNRLtTTz0VJSUlGDVqFC655BLs3LkTjz76KHr06JHwIrQyePBgPPzww7j99tvRr18/9OjRg4zJ9eMeFQTBB7KYCS8IWeUvf/mL/fOf/9zu37+/3blzZ7ukpMTu16+ffeWVV9oNDQ1JbSORiD1jxgz74IMPtouLi+3q6mp7ypQp9t69e5PaqaWPbNu2P/roI7u2ttYuLS21Kysr7ZtuusletmyZo/TRzp077fPPP9/u0qWLDSBRToYqfWTbtv3yyy/bxx57rN2hQwe7vLzcHjVqlP3+++8ntWktP/PVV18lnU9Vkoniqaeeso866ii7tLTU3nfffe0LLrjA/uyzz8jx0pU+GjVqlF1WVmbv2rUrZZtx48bZxcXF9pYtW2zbTl/6qJVXX33VrqursysqKuyysjK7b9++9rhx4+w33ngj0Wbs2LF2p06dHNekyvS0xzfffGNPnTrVPuKII+yOHTvaZWVl9uGHH25PmTIlUT7Itm37/ffft2tra+3OnTvb3bp1sydMmGD/85//dHye3Hm13gv33nuvo626ThTpSh+17b9nzx77qquusvfbbz+7U6dO9qhRo+xNmzaR13nppZfsww8/3C4pKbEPOeQQ+/e//z25ps8995w9cOBAu6yszO7du7d99913248//rjj86yvr7dHjhxp77PPPjaAxN8UVTLMtnn3qKnPXhAE91i2nYWIckEQBCEnsSwrkTAlCILAQWI2BUEQBEEQBN8QsSkIgiAIgiD4hohNQRAEQRAEwTdySmz+7W9/w6hRo9CrVy9YloVnn302bZ8VK1bg6KOPRmlpKfr164d58+Y52syZMwe9e/dGWVkZampqsHr1avOTFwRByANs25Z4TUEQXJFTYnPXrl048sgjyWLQFBs2bMDIkSNx4oknYu3atbj66qtx8cUX48UXX0y0aS01M23aNLz55ps48sgjUVdXhy+//NKvtyEIgiAIglAw5Gw2umVZWLx4MUaPHp2yzQ033IAlS5bg3XffTZwbM2YMtm3bhqVLlwJo2Y/5mGOOSdThi8fjqK6uxpVXXokbb7zR1/cgCIIgCIKQ7+R1UfeVK1c6tq+rq6vD1VdfDQBobm7GmjVrkradC4VCqK2tJXcHSUU8HscXX3yBffbZR2v/bEEQBEEQ+Ni2jR07dqBXr14IhTLrpN27dy+am5u1+paUlKCsrMzwjIJPXovN+vp6VFZWJp2rrKxEY2Mj9uzZg2+++QaxWIxs88EHH6Qct6mpKWlrtM8//xwDBgwwO3lBEARBENpl06ZNOOCAAzJ2vb1796KqQwW2Q09sVlVVYcOGDQUnOPNabPrFrFmzMGPGDMf5s06ZjeLiDlmYkSAIgiAUDpHIHvxp2dXYZ599Mnrd5uZmbEcz7sNwdHApofYgiuvqX0Nzc7OIzXyiqqoKDQ0NSecaGhpQXl6ODh06IBwOIxwOk22qqqpSjjtlyhRMnjw58bqxsRHV1dUId+iIcHFHAIAVD24obCjAcxMEQRAELtkKXesUKkYHy52ECtkWEPdpQgEnr8XmsGHD8MILLySdW7ZsGYYNGwagJXZi8ODBWL58eSLRKB6PY/ny5Zg0aVLKcUtLS1FaWuo4Hy0OwSpuiR3xIuhCMX/FYBDudRG8giAIQq4SCgMhlzo3ZCMYD+AskFNic+fOnfjwww8Trzds2IC1a9di3333xYEHHogpU6bg888/x5NPPgkAuPTSS/Gb3/wG119/PX7+85/jlVdewX//939jyZIliTEmT56MsWPHYsiQIRg6dChmz56NXbt2Yfz48a7nFwuHECpqEZtxppiiRFeccQd7ErPaPZ3oWnA575GLCFdBEAQhk1ghCyGXVlXLLtwE4pwSm2+88QZOPPHExOt
WV/bYsWMxb948bN68GRs3bkz8+8EHH4wlS5bgmmuuwQMPPIADDjgAv/vd71BXV5doc+655+Krr77C1KlTUV9fj0GDBmHp0qWOpCEOOpZNm2nFVEUdJdb8Fl2UxdXWFI0mwwy8CFcRqoIgCIJbwmEg7PLREy7gx03O1tkMEo2NjaioqMCIsf+F4pKOKdtRAisUd9rUKQHEca1zBRxHYAU5DCAI8bAiUgVBELJHc2QPFr1wCbZv347y8vKMXbf1ef9E5xPR0WXM5m47ivE7X834nINATlk2g060NAyUhAGkEpaEZTDm/GlEueBDoWRRSood4mcWJfyCGjLCFakca6rfgpRrTRVRKgiCkH+EQpoxmwVKTm1XKQiCIAiCIOQWYtk0SLQoBLTGbFLxjWQykJ4FlLR+Ei55EoYFlGv95CQ4ca17caZllgM3ljQIFlCxfgqCIOQWobD7BKGQJAgJJogVhWC1ZqMzRWRRxCnrOEKPMknHibNxoiFHlHJN3iZd8qRwZURge4kRpURppmNCs5HsJQiCIOgTslpc6a76BDWGLQOI2DRIc2kI8dKWmE1KRFKWTQqOVZQrZumYTVqq6kCNpJs5rxsHyRGkgMSECoIgCGYIhS2EXAZturWE5hMiNg0SC39n2aQES5gQoBQcIalrEQV4VlFdi2iq8Z3X45FLbvpsZMnrroUgCIKgTzjUcrjq489UcgIRmwaJloZhf2vZ5ApLClIwRJPHixY773KTcaJBsIgCPKuoFzd0puNExSIqCIKQ+4RCGpZNFK5lU7LRBUEQBEEQBN8Qy6ZBQkUtBwDEmDqeskSFo+mthVyLJeVu17VPcpOS1NG4iT+6SUkm4z8Bf5OSgup+B8TaKQiCwMUKu08QKuCQTRGbJikuiaGoNAYAiBLBHNGQM2LDZghLCkrsFBFjke52Rla8yfhPLznrnC0yTcZ/pmrn6FcA8Z+ACFBBEASKUAgabvTCRcSmQYqK4ygqdiesopq3n83cioAT/wk4RSkpnJjC2Ck1zRZSUufGFWu6WfFBif8MggAV8SkIgtAqNl328WcqOYGITYOUlMYTlk0vcARonCtiqESlImJ8lpB09uO48ylXPoVuAhLH+gkEIwHJdE1QFUlAEgRB8J9QyEKIWXYv0aeAE4REbBokXBRHUVFqcRSissDDvIeyQ4B62GSV2n1Iha4JqmfZ5LrydWNCuehmwAe5JqhKUC2igAhQQRDyB63SRwX8FVjIVl1BEARBEATBZ8SyaZDSkhiKDbjRKVSraDTi/J3AzYDnJBKRrm/K/c5xc2v1aq9v+gQkLxnwujsgUUj853dI/KcgCPlCKCxudDeI2DRIugQhyo1uEirWU1eAUq7voApQrqvdZAkmaq6sLPYAF6CXPeEFQRB4aCUIFfDXm4hNg5SUxFFcol/mxw90BSj3oc8RoEGt/wnwBCg3AYk1V4PllwD/E5BEgAqCIDjR2kHIFsumYIBwUTrLpv7YJq2iLAHKtmISzZS+VC+u5TQbApSDbga8bvmlVO0c/QxaPwHnewqC+x0QASoIQnaxwgBROrv9Pv5MJScQsWmQktIYSr6N2fTbZW4aVYCS1k9NAcoVkboC1Gz8JxCKKwXuuZnnDAFqsv5nqnaOfnnmfgdEgAqCkF3EsukOEZsGKS2yUVLU+sALhjtdV/Ry3e90iaT0eNkBiWM5NSlAVfEJ6AtQKUDvH5KAJAhCpghpWDYLufxPIb93QRAEQRAEwWfEsmmQ4hBQ0irfiyirSjCsnbpQ1s54lLDwEdZIJ4QVkxhft0yT39nuutbOXNrtCMhta6e42gVB8IuwZSHs0o0ejosbXTBAWQgobc+s7rMANRknyh2LJUBZ4hPQFaDcHYooUWoy2SgIAjSXst2D4GoHRIAKguAerdJHBexLFrFpkOJwy+EKowKUupP9KTLfHqpA1Ld+AhwBSlk/qRXkJiD5KUBNJhsBuZ3tHgTrJyACVBAE92gVdZcEIcEEpWGgzK3
YpMgzAUpaP+NefuIplk2m+11XgAY12QiQbHe/YFmkRZAKQsEilk13iNg0SEnIRknC/Wz4F4xDgPIkUChMPRApReyfAKVc8tGo86+OEo089OM/OQJU1/oJ+Bv/CUi2ezYRi6ggFC5WyIblMnTNbft8ooB1tiAIgiAIQrCZM2cOevfujbKyMtTU1GD16tUp255wwgmwLMtxjBw5MtFm3Lhxjn8fMWKEr+8h5yybc+bMwb333ov6+noceeSReOihhzB06FCy7QknnIC//vWvjvOnn346lixZAqBl0efPn5/073V1dVi6dKnruSW70alfMPrWTqehyDl+KOTFOqlaOzMf60kRBGtnEGI9gfzLdje9taZKUNzvYu0UhPzDCrUcbvu45amnnsLkyZMxd+5c1NTUYPbs2airq8P69evRo0cPR/tnnnkGzc3Niddff/01jjzySPz0pz9NajdixAg88cQTidelpaXuJ+eCnBKbQV/04lC6vBeeACVjjqNaUwJKdUVjZl3tbsi0AA1qshGQ29nuprfWVAmK+10EqCDkH5Zlw7JcutFdtgeA+++/HxMmTMD48eMBAHPnzsWSJUvw+OOP48Ybb3S033fffZNeL1q0CB07dnTontLSUlRVVbmejy45JTaDvugdwi2HG8LkzUcICPWT0hWfgO8CNEzGifqHvvgEbEYMTVCTjWiYsbwFkO0eZGS3I0HIbbxYNhsbG5POl5aWkkau5uZmrFmzBlOmTEmcC4VCqK2txcqVK1nXfOyxxzBmzBh06tQp6fyKFSvQo0cPdO3aFSeddBJuv/127Lfffu7ekAtyRmzmwqIXWS4q+rQLQ4BSn1xgBGj2i9ezBSi13zvjQyTFAaMAPSmciH7kNYlzTjGoX84+09nuQUk2opAC9IIgtIcVsl3Xtm5NEKqurk46P23aNEyfPt3RfsuWLYjFYqisrEw6X1lZiQ8++CDt9VavXo13330Xjz32WNL5ESNG4Cc/+QkOPvhgfPTRR7jppptw2mmnYeXKlQiHTZTUcZIzYjNIi97U1ISmpqbE69ZfKend6E74ZbrUm5ph/QSy4H5v+RGg18/fh6u2AKU+VEqkFhHtWEJSt5+zJy0GzQnQoMZ/AvnnghcBKgjBxbI0LJvf/klv2rQJ5eXlifN+xUs+9thjOOKIIxx5LWPGjEn89xFHHIGBAweib9++WLFiBU4++WRf5pIzYtMrJhd91qxZmDFjhq/zFQRBEAQh/ygvL08Sm6no1q0bwuEwGhoaks43NDSkDf3btWsXFi1ahNtuuy3tdfr06YNu3brhww8/FLEZpEWfMmUKJk+enHjd2NiI6upqlIZtlIZbLUZ+V5VyWjhoKymRgMQwWpJjEdZOv62RfsOydnJd7RxrJ9NiqWvt5LjaAX1rZ1CTjQCz2e4Ust2mIAitZKLOZklJCQYPHozly5dj9OjRAIB4PI7ly5dj0qRJ7fZ9+umn0dTUhJ/97Gdpr/PZZ5/h66+/Rs+ePV3Nzw05IzaDtOipgnmLQjaKEzeT88EdsigFp1cOSd/9rn9Nir0m93bXTCzS7ZcKjgCNc4WMKkC1Xe0AJSVth9Ag7jvm6HQCkn8CNNeTjYLgagdEgApCNshU6aPJkydj7NixGDJkCIYOHYrZs2dj165diUTpiy66CPvvvz9mzZqV1O+xxx7D6NGjHfknO3fuxIwZM3DWWWehqqoKH330Ea6//nr069cPdXV17ifIJGfEJhD8RS8K2SgyYOkLk6LUJJRVlFFChzmtiMUQKOQ6+VtuSdcKS4pP3c/ZaKwnATFWnJgrJUZ4GfD6Py548aX6ApTVj/krLdd3O5Jsd0Hwl0xtV3nuuefiq6++wtSpU1FfX49BgwZh6dKlifyVjRs3OvIk1q9fj7///e946aWXHOOFw2G8/fbbmD9/PrZt24ZevXrh1FNPxcyZM32ttZlTYjPoi56cIGT2y71MsUZyhR8F3VeZbxHRSDfZqMSL9TO9ADXtylc
tpeR2mxHnt0ZMM3QizhQCTismwLI8avWi+3Ktnxyhx3f5a1o2A5KAFAQBKtZPQTBLpupsAsCkSZNSenBXrFjhOHfIIYfAtulrdejQAS+++KLWPLyQU2ITCPaiFye50ali7YQ1idB0PCHJU5t8d7vaj+d+DzFUCzkHQoBy3eG5lO3OEaDchz5VNN5pydQvo2SyAD1HgJKCjhgpqALU5A5IQbB+AiJABYFLptzo+UIBv3VBEARBEATBb3LOshlkwm1iNkOEZTASNxeLSVkL6QQkHjxrKmWZTd+Rk/3eMhZlVdHLgNc0fpLje7GIqtZOytJJu8d5cKxrVMwmZSXNhrWTAz/BiTFWQBOQguBqB8TaKQhcMpGNnk+I2DRIWzd6cL6feXuvc9ztdL/07nauK59uRwjcUHr1ynfJE8lSulnxVCKOci4adkonKv6Ti+52m9TWmpSo8HO7Td293oHM73aUqp2jnwhQQSgIxI3uDhGbBimCjaJvxZcXIyZttVRf8zLKKcujbgkmvki1GW306n+S43mo/0nFf3KEKiksDZZg4u52xC7BpGByu03KMsvJgLdjXDGV/d2OqHZBSTai8FuU6lYDEIR8oSUb3d097sXjluuI2DRIWssmVUGH+M6m8kCcUF/25r7caeFKtUs/Fiv7PdU1qTVTsuKpOewlyi9RYrA55Mx25yQgcayY7H7EvMgMeGJesajerxr9DPjgJiBxMJmAxLbCapZgkgQkQQguYtl0h4hNg7Qt6k4+J4jnI2nMI25I1V3NtVhyrKTUOW5MqK7YJAUisWhBiAmlRSoxvq7Y9HBOdctHCfFsE4sd5v2i0Y4JzbRL3t2I6cenkJhQ74gAFfIFCxqljwyXRMwlRGwapG1Rd651nWxHPNV03fJ8sam6vs255Lnud/Ic8XBV2+1limBKlFIxp3tDenU8OS755ibKkpp5lzz10KfOUfGeKkFwyQNOt7wXlzyFn0XpxSXfPiJIhaAhlk13FPBbFwRBEARBEPxGLJsGSRezSf2ApwxHlLs3qpg2yVhG7fhPrhudd07t20QkgnBd6+Q1FQslaZ1kxn8Wk9bO5NfNROV6yvIYJbaKVC2Z5OfmwY2uxpyy4z+ZOyCpLnjKChUmbmKOlVRKMmlcMwDxn4DsgCQIUvrIHSI2DRKywghZLQ9/KjYjxNgzHOC54Lnud0p8UC5ytQaorvudGp9ytRcRcQGUKOW42/cy3ePFxPpwXPC67nfAKfSam3lu9CJishwXfDRKCF4yCSp9/CfgFKW+u9+Z7TgueF33O2DWBc9JSsqlkkwAzwUfBPc7IKJU8AdLY2/0Qnaji9g0SFuxaRNik/rOKyasZhyrqK5FFOBZRSmLKDU+J2lI1yLaMg/qIZm+XzMxf0psUqKxWY2n07SItoyVfFFdi2iqeRQVJY/fxIwJpa5JCVW1HdmGmZSkigNdi2hLu/RCkrKSmk1KMhf/SREEiyiQf0lJIj4FE4hl0x0iNg1iIQSr9TFCld4B8VAmShhxrKKUSKKEGS3gnOfCyv7xlIiM23rucF2LaKp5qOcoQU1aMZkCsVh5GJFrSLmTybHUNsTnRphOSYHIEIOkmCVc5k2UtZPhgvdSpimiXJMtLImFpYQkZyemINcJdVyPOJfrZZqCkJQkFlHBBJIg5A4Rmwax2lo2CWFmE0KDEpuUVVRtR5ZQIKyk1DOAs5Um9QOMtHYy3PSUYCwmvvCp8YuIL/xiZa6UyKbc9HT4gPOaqkAsIb4gqDJNxKY82KsIEnosYg7EZ7k3RmR4O8SmU0RGi/Td9BGGZVPXTU9ZRGPEIlJlmijRorbjxpdmw02v1jmldlMiYbrDTZZp4vQ1bSUNqpteBKjQimVplD5y2T6fKGCdLQiCIAiCIPiNWDYNErJCbWI2mRZLTQsoNZZFxU/aznYxm7DuKF3jRD+2az3N2G7Gz0ZMqGr0o+NXneco132JYv3aQ3Qsdhr4sDfqPEdZRfcoFlCuxZKyUHJ
iQul+ejGh7LhRzZhQLwk2HAsoaaVjJEYBzr8RKks+TsVWMy2gJmNCc6lOaBCSksT6WRiENBKEZLtKwQjWt/8DWlzqKhz3eMp2iqjz4pKnRGnITvblmizdFCYEXYwZ/1lMCBlVIFIuc9XVTvXj9iUFKdMdriYqUe+HdqMzXf7R9G32EFnykbjznCosAaf4o4Qr1Y9KVNIWm8x2kaLkr7MQMwHJZOkmbnyp2pfrymeLUjUxjRrf2Y0kCIlKue5+pxBRmttIgpA7RGwaJCkbnbDcUV/buqLUpJUUwHeJTd9ClWlSBSlAJ7yoQpJ6lsSI9Skm5k9moyunOIIU0Bel3H4cgUgLS+e5CPEg4ohS/vjOc5QobVIEFkeQAkBRMZWolDxXjiAFzGbOq4IUAGKE2KTiRFVRRAkgOnM+vTWVI0ipfqnmwcmc51p5M505T83DZExoEJKUABGluY4kCLlDxKZB2majU7s4kpZHTVFq0koKOEUpR5ACtChVXfdhog1l2eS61lVRyhGkgL4oNWkl5VosSSsmQ5R2IFzye4gEFY6VFAAiSgLYniJzVlKOIKX6ASmEKqdMk6aVFHCKUkqQchOcONniXGHJcflzM+c5VlJq/EK1kgLBEKUiSLNAyKItKen6FCgiNg2SbNmkvladX8m6opS6ZW1iMF1RajKbnhtLahOPIk68J0eQAikEIiFAOW70Uub4qs6gyjRR/ZpiznuFY6Ek64uSItvZroyKHVVEShlhOSXDB4hyTntjyedKSpxtqKL3VOkmjpiNEP3IbHqGcCXHJwQp5bqnxnK65PVjSTlxkNRYlEjSzqZnCFIARrPpswFZpUBTQJgUqVwrqYqIVH2skAXL5bq7bZ9PFLBRVxAEQRAEQfAbsWwapG2CEBWcQVojCQuoRRV/d7i5mf0oCyLDAsp1yYMsVB9XXlOWC6dVi7qmms1Njc91j9MubKKdstZF1BaNhGWTl2zkaELXEmW+J/UcZTmlLJbUPCgLZZniRifbENekLKxqhn2EsE7uLXGeixCWNdIaqVgym5vTb78J8ONQ1b6cNqnONSl9w8QcqPjPKGHeJveTVzcmoFztPhe4V62fLf306olSc435nI1eCG56iRv1QDjUcrjtU6CI2DSIZYVgfSsy2cKSmSAEpS8pLMnMc2IohpueEqQUuslMFje+lIj3VOdKxo2SiUvEA50UkkobQkyRYpNYs6gyfinx0XJEJMBzwZMuf+b4HFG6hxCbVDITmYAUTi+8qZJP1Ph7ipwNHSEFzASkkiZC1DFiR8m4UabrniNcmyLFjnOUmz5GCFWH2GQIUkC/6D13LEqUEreUQ2CZdNMH1UUPiJs+pwhZ7mMwC9iNLmLTIEnbVdINeDBEKS3giKG0LadMscmwnJLimRLZxFxpMZveckqJ1BAx1zCRYc+xnFIJTlQikfrsoARjCfEFFGXGhKrf5boildu3I9GGsnZSQrKjMllSbBLfSM3cayrn9hLJTM1EMlNTB6dw5YhG2rpKxN9yykBpitRU7ZqUdpRIZQtQop1qdSXjPwkrrMkEJ8pySsEqA5UFy6lJ8k2kAsEXqlYYsFwmCBGPvYJBxKZBktzo5L9zrZHOL2lVAJm2nKrjUQlCFNykIUcbpnufc01dkQogkdDVXjtqrkXEWMVE6SD1+5ISqZQYpDxxlDXV4UYns/ydY+mKTapNB20rLDUv57k9xGJw+nKz8PdGifFLnZ/lHiXBKdIp4mhDiUFO0lNzEzNLnhrfoJhVRSpAC1VL+Uwolz9XzHKsopSY4ic9KcKbGIvux7SmKthMN7ruPvcmMSlSgcwKVS9C1ghi2XSFiE2DWJZFCsV2+xDnuNZIRz/mtSnR5Wyj70ZX4zjJzHzSvc8TiLr9SGFMVg3Q60cL7/TCNU6IVGr+1Pe4Kl65wpVul97CyrWccq5pMnwAcFo7Kfc7GUvKtMyqFlbaukrEoRLn1HmQ12OGAXAsrOzyUdphAM6xuGEAqnAFnOI
1OMI1vVWUK1wp6GoA6QUc28qrWXKLi98RiYGyLIct96WP3LbPI0RsGiR96SMe7DhOtR9xzm/hSvZVhRLTZU6PxbAkUMKS0Q/gxYSy+5E1TRWxyRWuXAGtxtqGee/HJiLlOG56jnWV6gcAEZYw9hIGoNuPd81m5U3xRarznFMYE9frSJSPilKfm9PC6hifsPKSVli2wNUrM6UrZpujTuFKjZX7VtjMilkyfIDoxxV51Na1rH5cEdlGrMXZcWn+YFkapY+Y3rt8pHBTowRBEARBEALOnDlz0Lt3b5SVlaGmpgarV69O2XbevHnfelm/O8rKypLa2LaNqVOnomfPnujQoQNqa2vx73//29f3kHOWzTlz5uDee+9FfX09jjzySDz00EMYOnQo2XbevHkYP3580rnS0lLs3bs38dq2bUybNg2PPvootm3bhmOPPRYPP/wwvve977mem87e6Pyx08ON2eTgZa48N7q59aGsjGQ7rjWVsD6aGosTFtAyPtOy6SjGr28xLiEyftV2cUZCVcs80ltTubGkutZU1ZLqZizdmFNdayphbCMtovqWU+di7yWsWlQClZp4RY1PzaEpyrOmckIDuDtLcSys7Dkwy1jFlfwyT5n/jJACXYso1c7LWLoWVu74FG2vGbWzbCvLUOmjp556CpMnT8bcuXNRU1OD2bNno66uDuvXr0ePHj3IPuXl5Vi/fn3itWpRveeee/Dggw9i/vz5OPjgg3Hrrbeirq4O77//vkOYmiKnxGbQFz1dNjo7GV2zOIeusKTnYDA2xoMbnQN7rlScqOY8OIKU7McWmwbDDJhuenr89DGn7PfEEMacUISU7TTnSglok/GxQU3s8la5oP3XgDPsAKDjVzkCWt3JCqBFME94O9vo1nIFqCoFeiI49fiq8OaNFacy+HNYGKvjRagPO4Nkageh+++/HxMmTEgYzubOnYslS5bg8ccfx4033khfx7JQVVVF/ptt25g9ezZuueUW/PjHPwYAPPnkk6isrMSzzz6LMWPGuJ4jh5wSm0Ff9FaTtVc4MZVcOPGH9Bx8himMdYU3fx6cOfgblM5NoCL7GpwbK1nKw/VYyV4BFdkt42laqTmxvAZFNj2+OZENOEWpySQ0wClwucJbtxYt3Y8njDlbxupapFv6Js+NHF8zfphq00RthqAZf0sZOql+tBh3PiNibdYitscZx5tRMpAg1NzcjDVr1mDKlCmJc6FQCLW1tVi5cmXKfjt37sRBBx2EeDyOo48+GnfeeScOO+wwAMCGDRtQX1+P2traRPuKigrU1NRg5cqVIjZzYtHtb48A0V4pplzApPD2MIncHj8g6P7w8Ru/f9Bk5ccKp5+XHw4Gf5jk248Qdj/2Dwy99TFp/aevmb4v98cL1Y7yCLT9UbOzcQ+OviL9PH3Dg9hsbGxMOl1aWorS0lJH8y1btiAWi6GysjLpfGVlJT744APyEocccggef/xxDBw4ENu3b8d9992H4cOH47333sMBBxyA+vr6xBjqmK3/5gc5IzaDtOhNTU1oampKvE7cOHYcVEF2IQvk4+cQhPfkpcoCpxF3O0NdyOoM5sYy2tfLWnDG586fIywJKykJ9z2p1/QyV/UcNYdsjE+do/qqSixKbLNFqjXG+FQ/ouIBaywAtlKLlvV+AP57atOucVez898ziBc3enV1ddL5adOmYfr06UbmNWzYMAwbNizxevjw4Tj00EPxn//5n5g5c6aRa+iQM2JTB78WfdasWZgxY4aJKQqCIAiCUEBs2rQJ5eXlideUVRMAunXrhnA4jIaGhqTzDQ0NKcMDVYqLi3HUUUfhww8/BIBEv4aGBvTs2TNpzEGDBrl5G67IGbEZpEWfMmUKJk+enHjd2Njo+KWSFYJg+TJNNt5Tpq9p0prnt7XNZD9dK5eXdpS7UdcqZ9IaxrVymRw/RliTdK1tZHAe00LGacccy2FZo/pqWNHanYdyzmYGUNrkOaLag7NoKnMsxvhUm73
O902NFSeqDcSUpCSyDXGOSmYi27U5tyPirC+bUTy40cvLy5PEZipKSkowePBgLF++HKNHjwYAxONxLF++HJMmTWJdMhaL4Z133sHpp58OADj44INRVVWF5cuXJ3ROY2MjVq1ahcsuu8zd+3FBzojNIC16qviKJHJN+AVhvn7PISiizuRYJt2xQXC9+i2wuGNxXKFeXKNqO0r4cV2jBsUa1Y4UcJzxNcUa2ZeaFymUqHYMgcUReSmvGU3bhhKI8SZnO44QixIJPFyxFmtWErtiVCkn57ONKxCdtzVV/sp5X0eIc9SuSG377iKun1GsEBByWcpIY7OUyZMnY+zYsRgyZAiGDh2K2bNnY9euXYlE6Ysuugj7778/Zs2aBQC47bbb8IMf/AD9+vXDtm3bcO+99+LTTz/FxRdf3DIFy8LVV1+N22+/Hd/73vcSVXh69eqV0FZ+kDNiE8iBRfczZlPEoDsK1cLnQWCxLHwm491MijXu+F7i6VTBw4xj0xeDPGubtjVPV/gR7biWNTQTAq4pvUBkCz9NyyBX+MWIrGm1HSXCos3ORy3HMki14wo/anxV6EWihMijbgFCDEaJckhxpSwBUenK0aZlLN482grQ3VSDDGKFLVguLZtu2wPAueeei6+++gpTp05FfX09Bg0ahKVLlyZyTTZu3IhQG9H7zTffYMKECaivr0fXrl0xePBgvPbaaxgwYECizfXXX49du3Zh4sSJ2LZtG374wx9i6dKlvtXYBHJMbAZ+0TOZIJQN8VkoYlD3feqKQW6mrUkxaFIgFoIYpNr57drNghikXKgcS6CuFTBVO1UQmhSDsQhRUidqTgxGm7lWQKbVTxFiVBu+gFPEJmlRdJwirYzpxCDAF5vcr9y2ffdk2wATsloOt300mDRpUkoP7ooVK5Je//rXv8avf/3rdsezLAu33XYbbrvtNq356GDZQa1HkkM0NjaioqIC27/+E8rLO7WcFDGYuX65FPvndyasaQHnZ6atoazUlO2Ywoz8CuSIP4OuY09xfgwBZzcR8yIKNpKWQdLyqIxPzIG2rLmPzWvpR7l7ubF/yhyYbmhyfOo3iGotpKyAGm7iVlTBxhVrnPEpMUjBviZjJyDuXCna/m7bbUcxsWkFtm/fzop/NEXr837Lg2ehvIO7Wp+NeyLodtWfMj7nICB7owuCIAiCIAi+kVNu9MBj2+1bqbJQQ9D3vkGwRhZKnKLu+NysY7+tkaykEi+u6fRJJWw3OsNqSVoLNbOCKTc0O+uYZdlkJM4AiDc5TpGuY9VVHKfc0GQ/bnKL0sZgVjM/QcW51rSFkjGWB2uk6q6m2ui6qzmWSKpf6nbpx6P+dHXGyrZTNlPbVeYLIjZNEo+bEZSZFn5e+nLer0n3NdHOU0Fpk3GWHLFmciyqnW7pGoDpmua6udMLPZZgZI5FtuO60TVFIzeDmRSNe9O70UlhqSkaacGo75pW+3IEY6rxOaLRW1Zz8jpyBGPKsTgCjukep+CIRl33NdVXx33tBu74ZN801/QwtBnCoZbDbZ8CRcRmJgmy5THTSTHMdnkXB2k6plK3xE0Q4iANJsBQ59hxkMykFT/jILkWS904SE7ijJvx/cyQBpxJMWQCjGYcJCXouGNlOynGVb8MWhlT9stuwri/hKFRZ9OXmeQEIjb9IteFpc/JNGQz3bqLXqyFnPH8TrCJUyKPKRAN1lPktGPVXEw1vs9i05m0wrMC6gpQL9nWqlC1yaQSnuWRk/3Mz9x2nGJZI7l1HnVrLJq0RnrJtqZKBaW7XsprBsAa6cXySI5nUFymm1s82250S8ONbokbXTCBTumjIAhL7nhBLehtWmz66fo2KSwBXrkfkwW9dYUl1c6LsGSU2jFu2WQU79Z3c5sTli3tcsnNnd4aycnSbplXeqHHtWJyx093Pe68Uo+X/xZK06I3I3jYQagQKdwAAkEQBEEQBMF3xLKZSfx2rReCFZN5TTaZ3pbQpBWTamfaje7ntoS
a7nEAsIk1096WkJ1BznCjk+MTVjPFkml8D2mH5dHRJEsZ3tzYS+V6HqyFflvNdDO86bHMWTGDSk5aMSkyWNQ9HxCx6RdBEZZZgJ0d7uhoMKQgyEk9nLH4PjZGG56YJZN6nAFj2uNzyijZxIOau+0hlL5cYan2S9XXMV+qDTEWJeDU28KTsGQUH6d2q4kT52ybt9ONc1tCLwk2xDk1zjKgpYNSteP045LpTPBMxl3mMpnarjJfELFpkkzGbJoe36S1U8XLN63flk2T6NappGALOJvRRlNYehifNX8yw4MpBknLpqrgiDlwhSVD9FJzoIQfJeBU0UiXR+UJP1ogtv8aoMUsd9ccpzWPGp/qxysnxNnZJijZ4rr9TAqxIAhLL/PISUKhlsNtnwJFxGauEFBLZsatmNyxsrFFI4eAWDHZ43PEsqYAJa2Y1PiUOqCEpFrHkyNIXbRzzIMUKOmFX8s5RWwyBGOqc/TtqWSLM+fFtdyplkxS+JHnnONT6Jb78RtdF7nf9SyFLCBudFeI2MwkhWDFBPSfDLlkxaTgWvg4bUxaC8lrGhzfoBuda8Vk9eUKA2Y7dR7k7cp0Q/PCgvWsmNzxudACtP3XKcfKsBWTGt9LTCVFpoWwn4XSMzWPvEAsm64QsSmw0bZimoZjefTzeqnOcTA9V8YXPt9lbnBunHhMD2KQsopy+tE+VD0xyxGRANMgTYzFhROPqRuLCdDxmBy83E6ZF3CZvZ6QB4jYdEXhvnNBEARBEATBd8SyGUTy7Wd2Nlzhpvcg18XvLANT/VKOp5kgpAvHp5oC0t2u0YY7D371Lp473NlP/5p+YzS5xbBb2xQSnym0i6URsyk7CAkFT6YFYS7FYnLRfQCbTBAy2c80jphHDyKD8Z6447Nc8vD3ljU9tm71Ll0KPn5PKDzEje4KEZu5Sj6KNSFzmBQHWYj/DCqUFZNuZ/Ka5sby+3eJB8O1IAQLEZuuELGZ74go9YfAWAt9Vhq6dUKzQUDdsSaRP2dBCAhS+sgVIjYFQSgcCkCQCoKQAUKWhmWzcMVm4dp0BUEQBEEQBN8Ry6Yg6ED9ojW995vOPEy79x3vM8dTaKm9iSOZn4YgCDmOxGy6QsRmvmMRN7cEfgkUHAFNuYHIL9AsiFJVSAZYRFJ/ln72EwTBMBKz6QoRm7mKaRGpjue3IM1HEUx9kQT1Lfk9V0qAhjTjJTW/oC2iHzUDi7B22oRQ5Qg9i3yP/j5gqHk5zhHan/yIiHMmM8jDAXEICIJnxLLpChGbQYS0MAVVtQie4X7eLMsjJeKZ21VmGK4YJCG/tH1WRZHk8bmRFKQAjVlKGzeTc4/fv+1CxGcptTeFvEbEpitEbApsLCvsOMfaL93/J53znJfhHVYhD4PpxlRylYz6kPf0vilTV5zRhhAVaj8upOmLGJ+0PCr7gRPDkwKXazlVLaBRanzeUJalzs45B2osaqmz8bNBFZfUbjvsW5iwLAdBqJLzYlQz4ArvsPPrVHYVyiEsy4Llckcgt+3zicKV2YIgCIIgCILviNjMJFbIeeQjre4FHTeDCrVmmV5DzhyskPN9pwqKcxyW86DgtlOn/+0v8LYHa66GscJW0gHuwRiLP17IeVAw1toK2Y6DHIp1W9iOgzs+1c4xlkUcVL+w5TjCIefhmH/YeXjBuWaW4zBJFv4chFwn1Xd+e4fm82rOnDno3bs3ysrKUFNTg9WrV6ds++ijj+K4445D165d0bVrV9TW1jrajxs3zvFMGDFihNbcuOTcn1Q+LLoWnG9CrhDjtAuKqMt1NMRhSz/m04/zBcfuRwgs1peouX4W82BdkxCb5FhUO+pw9CVuWaao4/5+4Y2vORZxsJ+bDGFJCURazPL0v2N8QuBS11THpubgBRGpBYpboal5gzz11FOYPHkypk2bhjfffBNHHnkk6urq8OWXX5LtV6xYgfPOOw+vvvoqVq5cierqapx66qn4/PPPk9qNGDECmzdvThx/+MMftJa
BS07FbLYu+ty5c1FTU4PZs2ejrq4O69evR48ePRztWxd9+PDhKCsrw913341TTz0V7733Hvbff/9EuxEjRuCJJ55IvC4tLdWbYFuBxI1JzPGs7EDEcZLBbUQ73WX1NBYj5pHsxkwf1i1NRAWHceIxqVhMzThOMguc2lucGcdpqX2p+Ezimuw4TnUexByopaDys1QrpcXNB6OsisT4tnKSvnV4IosSY2F1F1NyfJNxnM6xspHjFoQ4Ti/JWOoPAy+VAAo+KSxDpY/uv/9+TJgwAePHjwcAzJ07F0uWLMHjjz+OG2+80dF+wYIFSa9/97vf4U9/+hOWL1+Oiy66KHG+tLQUVVVVruejS06JzXxZdCOQ39rEty9X1HFKH+kKRO5cKTjWzSCLc8d7p/Yfp/pRX0qUqFM784SfRay/TYpGgwlCal+yHzHXYkosU/NX1qyYeHJTIpVoZ5PtlHkQcwgVOc9Rt3pY+ealRLYjhwhAqMh5kv5TTZ/tHgozhUGUELjqR0ksNZ0jRrvgnSTPjf7tRf2NcEpPcctT6QlcL193uhS88MsGOpZKl+2bm5uxZs0aTJkypc0QIdTW1mLlypWsMXbv3o1IJIJ999036fyKFSvQo0cPdO3aFSeddBJuv/127Lfffq7m54acEZs5t+heLHd+Wzs543uYg2rtJC2d5DcyMRjnmlx3exCsnVQ/dv1JjvmLasOzdrIEqK6wpPoS/djWzpL0QtJh6QQAQrjahGWKFKBqO6aYDVMC0bbStqGg/hxs4ltcHV+1dAIprLDM35iqhU+1dAK0tZMScBxRSrfhiUaO5ZELrRXSZ+bTY+lZO01mrFNC36S1M68FrwfLZmNjY9Lp0tJS0qO6ZcsWxGIxVFZWJp2vrKzEBx98wLrkDTfcgF69eqG2tjZxbsSIEfjJT36Cgw8+GB999BFuuukmnHbaaVi5ciXClHndADkjNoO06E1NTWhqakq8Vm+clJgUoFS/bPykZrwnbVd7qvHTXM/VWJkWoOTnxrRiktdkfJlTbai/fKKUj8MHnIVvDPLrnHqgK+KPXV6dK0ZijM8kTn1nOO/1MOMhTJZaZT4HnJZZrhWQaGUzLJvE0hQXO/tFiJJVtIBLnhsdx8l7T84107WIMl3mZAwoz0rKEaDckkm61k6TAjSvLa6tceJu+wCorq5OOj1t2jRMnz7d0MS+46677sKiRYuwYsUKlJWVJc6PGTMm8d9HHHEEBg4ciL59+2LFihU4+eSTjc8DyCGx6RWTiz5r1izMmDHD9zkLgiAIgpBfbNq0CeXl5YnXqfJEunXrhnA4jIaGhqTzDQ0NaUP/7rvvPtx11114+eWXMXDgwHbb9unTB926dcOHH34oYjNIiz5lyhRMnjw58bqxsdHxS4VNVgqe+xyzyRiLtHaSlkHmPDhw11VdM5MF4qk3Sb4dMpCTOMWI2eRaMal4SeU1aZPgutF1v20oS0hx+rlSkPOPOe9F2gLKMCtyEzVi1AegUOI8xejVMj7LLU+FLBBjMWI7i4ixohGOlTFV3/T9uNZI5+2pH7PJaUdbpM1ZO70UiNd1c4u1k8BDzGZ5eXmS2ExFSUkJBg8ejOXLl2P06NEAgHg8juXLl2PSpEkp+91zzz2444478OKLL2LIkCFpr/PZZ5/h66+/Rs+ePXnvQ4OcEZtBWvRU8RVa2egUPsdUsgSorrCk2jHH0hagpkskqXPjxpdyUDNDAIASHlwBWqSMFyVVJHGKGot4oijD6TtjnWN5+fbhzMNLQRv/dzhPfvOhvTwZScVn0qMnf75RZlhGESFwdaFEJEWEWG1OX3rFqB8myWOpQjZlP20ByhSRmgI0qJntQIEJ0Axlo0+ePBljx47FkCFDMHToUMyePRu7du1KJEpfdNFF2H///TFr1iwAwN13342pU6di4cKF6N27N+rr6wEAnTt3RufOnbFz507MmDEDZ511FqqqqvDRRx/h+uuvR79+/VBXV+d6flxyRmwCObb
opi2WfmaLA4Q1LxiZ7SwBatL6ycVkrCclQMlYW+rxqrRTxWfKsaihGO2orGmiGesxQWeQaOMoYRQ2u/cfS8xyazYyHjqhiHP+RcTKUre6U9cTWfLEHKLNzsEoAaqWbooTGeuUwOUKUN1bg+ynWliJWFJKJEW1BagHKyZDgOrGegK8ZKMgl1ZK+nfbys5erYkJ+J+NDgDnnnsuvvrqK0ydOhX19fUYNGgQli5dmshf2bhxI0Jtxn344YfR3NyMs88+O2mc1rjQcDiMt99+G/Pnz8e2bdvQq1cvnHrqqZg5c6Z+2UcGOSU2c37R/a4tyRV1FCxrXhZEHUOAktZPCr+/mPxONqL+XB0FGpmWTa5rXW1HZXgTkK2UfYHtLGwEbZURJ6kHPEc0kv2Y74lK31ahamoSH5IVSR9yQYkw2pPvHCtGCEl1j2f61iHGIvaGpuqcUtfkwCpNRLShjJ3Unwi1jk5LqRcrKYWm5dTnZCMKdXxuQpsXUZo1LA2xqfm8nDRpUkoP7ooVK5Jef/LJJ+2O1aFDB7z44ota8/BCTolNIOCL3vaXjsk6koB+CSC/s90pTLrkGZBf417KLZmEUZlIeyxqPMpKSj1dmaLR0Y5rcqLWOpr8mZCfG1kyiVe6yZkizSz5FKLuFec5VRTZlOubO9em5L6U4Io3EcKSKg1FzKNY+cyjTJFHn3POn5OYT7Wh/sQp0au2o8QOtawRRjQCV9hQt3qEiEMtUiylpPAjrkn+AGCVeDIpXDnX45dbCkJMaMbIkBs9X/DZLCUIgiAIgiAUMjln2cwZTNe85FgeOf1SwUnq4fQD/I3/pNoRbbi/H+kEJGZnHaiCgdT7ptaMXJ94+jbk+nPjOBmfJdlPz01sRQlzBmGVoy2gepZN2trJcJFTbnTKpU1YHlULpV3sfN9U/J5NjU/NVWlXFHbOgRqLsmKGqOL7ymdCWh4JV3iMXGvCLaz05brfQ8TfV1gZn7JOUpY1OpHIiWqNpCxyPPc7oJuARLvM04/Pdb9zLad+JyAl/buNgojZzBdEbGYSXTc0hUn3OzWeyd2O2O+b18wkdJIHYwckaq6ctWAnbBHnONfkCFKA9l1yXPCU+510mVMu5pD7NoDD/Q7Qosi52xFzrmQWPnVOCQMIURXK07vfAULoRZgimEgaovZxt5V21J7wlNgsJoRxnLGfPBVjSYllzm5EABBXbkWr2dmP434HnIKH636nXfdOoRSNKgKO7R53wnHB67vfAcfOTKT72kstBr2SWxSBz0ZvW33GTZ8CRcRmEDFpFdWN4/QiZnV3/QlKUpJ6OeKc0aQktpWXYe3k3iaUsORYRak5kHU2iVVTRSPVhnrAMNtZUWUdKYsola3PFb0MMWsVET9MCKslFDFoEWKTFpHEvKhtM9WYUKINGRNKbOdpEQJU3QM+2kwJRkKsEfU/Y8R7UsWrakltGYu4JmVNVcQs1a9IvXfArxOq/uaIx4jM/CiRGU7GhDrPcZwLujGhfGGpFxNqsiao40rEzlYZRcSmK0RsmsSy2r+ZvJQ+4lgHTSYlmXTTe0kG4jTLgiDVd9P7HF7BEaSprum3m54TXkG50TnClWpHiUFKzHLd9KoFlGH9bDlHiDpVlFKClOmmt5oJa6ciGm1CWIZKCIG7lxLLxFoo7YqJ+cebnHOlrJjUvvBqCaYwIRBJgUuMr4pLjiAF+G76YmX+EUJYUkZwyo1OWnljahiAsx8lLE1aSbnjq6JU172fiiwUsEiNuNFdIWIzk5h2fasE1U3vt5WUbWU0KEq5dUIZQ9FWUua3u26xf11RanIsjiCl+nHbsa2YlChN35clSAGeKCUFKbGzESFmKSGpilLKYkm50e1iKraTqHWriEvVbQ8AISoOlQoDIOah7lpEWQsp4UdaFRVRyhGkAC1wqXCBWCT5XBGx1lFCUEeIXbAoa6r6nqgdnSgLZaatpNzxTcWcMja28pd0xqVUfQqUwpXZgiAIgiAIgu+IZTOImHR
zU+SyS54az6RLHuDVCeXC2SeeORTLAurlczPpRs+0S57qy3XJs62pShwhw/oJZMIln94CyrF+AvoWUI71s6Ud8Z4I173qlqdd8umtgC3tQmnb6LrkAacFlBpL1yUPON3yXlzyaiY+N3Ne1wJq0iUPJFtAiyRmM6cQsZmrFKpLnjteLsWJsrfuTN/Ok0ueM7eguuSpviZd8gBvH3qm69vhkue6/DVFqUmXPOAUpSZd8i3tks8F1SUPOEUpO3GJOEddU01e8uKSL1KFMRFfquuSB5x/NiZd8ur4IRGbOYWITZPo3HxeLJQcgiBKTe+mpGJym05uX93vDL8/bqZwNSpKg2AlpdqZtJICZrPpFfHHri+qK0oNWkkBwkLZRJR8MhknWursR2bTM8o5cQQpoC9KybhOQqT6nk2vaSWl37eeKDVpJQWSRWkRdxc0v5AEIVeI2Mw2fotBk/Pw20pK4bfl1G8raUBrjnJFKa/mKHN9ODVHvVg2HWMRIszvbHrd0k0MQQp4EKUGraSAM5ve4rrMNUUpN5sezYQoVcbnCFJqDgBPlKqZ9ICLZCOGKOVm03NEKUeQAvqi1KSVFEgWpcXZFpti2XSFiM1cwe84TlNzMD2PbFhOTe6mZLLmKIXPH7lKTtUcpb7efM+mTx/rCcBXKylAiFKTVlKAmU1PCUs9Ucq2knJEKUOQAvqilGsl5YpS3Wz6TFtJAV7NUVPZ9OEMf/c5ELHpisJ954IgCIIgCILviGUznygElzyFyWQmwF/Xuun4Upbllze8LvpxogYTl6hrFkqBe44F1KRLHnBaMk0WuIfTApoNl7z/Be6zn00flAL3OhbQKNEno4hl0xUiNgsRccm3j59xogXskneI0qAmLqVqF9QC9xxRyi4DpbfrElneiV30npNNTwjLsFO1qDsnAeAVuCeEKzWW3wXuOaKU2t6TKnOkFpsHgFA0vUueFJak2EyfTU8J0jCxv3yMuMfSlXgqSrOdpf9oiM0CdiaL2DRJul86fgs4kxSqlRTgWZ24cEoHcfqlwqSY5YibwG4DCmjXHCUnQlwgqDVHKdSMC9NloNSao0ScH7UJtk1vgp12fNZe9QCsImq/d0UMUoKU3KicF8ep7mFP7V/P3ec+FCZEtVoGiogL5u6A5Ij/JNpYxLpSJWVDVNmkqCo2uTsz8URpqM255liWhZtko7tCxGYm8fKgDqpQLQQraRAsol7G41pJtd30muN7+HvQ3gaUWlbuWnDQtcx6sZKSl1TaUZm7VAowBdVOvS+o+4SwWFJChraKMsSm5tagFmUyoyy6lGgkxKajL7dYPuGmJ62pTUo2fTEhNonEJY7VksqIp/4cqHaUkFQ1Iz2W8xzVjrSmtmkXIeaUSSwrBMsi7rc0fQoVEZu5QhBEnS65ZCX1Mo8gZM5zx/LbTW+y6L0mgcmcJ9dM43qpCBNf435mzgO0eNUcixKgLDc91zLrEMaEwCLc9CDc9CAElnrOaqaK5afvB4C0pjqsrlT8KvGewkQ7S/nMKcFILiHlpqfWUTnHddOTsaNE/HDbvsXudJ55JGbTFSI28wmTFppsEJT567q+OXAfmlxMuulNClc/LaLcdkQbT256P29FL7Gqum56ysREzYN6SnCsohyLKECLWQXSIkq55H3GoqqPq+cokU256fdSMbPpbzJKrJFuekLghpR6QRZhEaXc9NQ16XPJ75PydFN/ulQURjqraNgiYokziYhNVxTuOxcEQRAEQRB8Ryyb+U5QXNi6BCF8IBvud11rpxfrcBB2UzJp7fQQq0rFYjmsnSZDBdgWY6Kv7u3Pcb8DKZKSFIsYVR6J3GGJaVtmZNOTI1GlmzjxpWRmPjMpSe1LJPmQLnMKKkEorIYUEBZp7viMDG4qS74ITisifasrllPi84iTcyCSkkhr53fzCHOT5fxCLJuuELEptBAEUadLUN3vXuYR1KQkk2I2yPGfDAFqNP7Ti4g3KUC5Y6lClXL/EpqIhHoKqX2
5TypCoajShu185ZZzYgg4dioLJRop172jHzOkQN12h7oe4d4PER8mlZmvissYsV5U9AZ1k5HJS23GpwRwRpFsdFeI2BT4BEXUcQjKXP2M/wT8LdME+CtAvcR/Gq2zqfcAIBOQqL3jyflrXZKPyQQkXTFLxnpSSUOMvpw2AB3/qVhdyc+NOMe2wnIgLaKUQCS2w1T7eZkH1wLKgJ6/emNQN4oZwRWyJWYzlxCxKXgjKKKOQxBCCvxer0JOQPJbgDLmT7rfOQKUGtqgMDZq/aTG445FzYOop+iAY/1M2U5Zf8Llzxag1FwNPkUtxj1MzUtbQhJ7knOhehLOcKKVGQFKbDCUWURsukLEpmAev615fpPp+eeSAPXbje7lmrpWXp9LQ7EEKNf6aTL+Nhvud7KvpvjgiFSAsJLyrM9sAaqKXrKkFOV+p4Qrsa0odU3OvIhdeRxjeRCbVF+HFZaYQ5iYrc20Ura9/UNe5m4CEZuuKNx3LgiCIAiCIPiOiE3Bf1p/AbY9hPbxe81ag9vbHpnGy3vk9M3Ge2S+J8sKJx2exjd5r1Brxr2menDXn2oXspSD0yZVO78PxhyKwsyjyHkw3rdFHcVhx4HikHKEHYdVHHIe5PhEu7CVfBBzoG6VcNh2HFbIeYSKko+sonu/aDBnzhz07t0bZWVlqKmpwerVq9tt//TTT6N///4oKyvDEUccgRdeeCHp323bxtSpU9GzZ0906NABtbW1+Pe//601Ny5ZeMJ4Ix8WXUBuCdBcmms2yKX14X75c94T933rro+XB5WuQPQbXaFKwRWgjn6USOUJV8uyHIe/wtVyCFJSRIYt1uEQjKo4TNW3JOQ8wsShillKGJNilhCWYeJos1RZ/5qxLI2/MffRtU899RQmT56MadOm4c0338SRRx6Juro6fPnll2T71157Deeddx5+8Ytf4K233sLo0aMxevRovPvuu4k299xzDx588EHMnTsXq1atQqdOnVBXV4e9e/dqL0c6sv1xuSJfFl3IA4LyoDaJSUug7lxNWjtzCNXS6dna6TeZFrNcKybZV1O4ctvpCF7TMK2dDnHIFdmprKfpDlLMOkUq1zCuCtKskqG/gfvvvx8TJkzA+PHjMWDAAMydOxcdO3bE448/TrZ/4IEHMGLECPzHf/wHDj30UMycORNHH300fvOb3wBoMbDNnj0bt9xyC3784x9j4MCBePLJJ/HFF1/g2Wef9bIi7eL6nb/66qsp/+0///M/PU0mHYFf9Bx90AWGXLKQCYWHn1ZMLn7/IOA+9U3O36goZcyVLSw1RaMXK6ZmP7aFUoWyflJjUXDG564hZywE7PGQAbHZ3NyMNWvWoLa2NnEuFAqhtrYWK1euJPusXLkyqT0A1NXVJdpv2LAB9fX1SW0qKipQU1OTckwTuP64WsVbJBJJnNuyZQtGjRqFG2+80ejk2pJPiy4IQoAwKaa4BOqpGUBMrg9H7HiwRjrc6kGGIehY/VIK0FDyQcFca8q1Hig8iM3Gxsako6mpibzEli1bEIvFUFlZmXS+srIS9fX1ZJ/6+vp227f+v5sxTaBl2Vy8eDGOOeYYvP/++1iyZAkOP/xwNDY2Yu3atT5MsYUgLXpTU5PjZhEEQRAEQUhHdXU1KioqEsesWbOyPSXfcV1nc/jw4Vi7di0uvfRSHH300YjH45g5cyauv/764P+qM8SsWbMwY8aMbE9DEJxWn1yraSoIfkPUehQEz1jfHm77ANi0aRPKy8sTp0tLS8nm3bp1QzgcRkNDQ9L5hoYGVFVVkX2qqqrabd/6/w0NDejZs2dSm0GDBrl6O27Q8k/861//whtvvIEDDjgARUVFWL9+PXbv3m16bkn4vejcMQFgypQp2L59e+LYtGlTyz/Y8e8OQcgEbe85ue/0iMedh9/I59Y+mV6fuO08mNi2nXQEmpidfOj2S9U3Fk8+KJhrbcctxxEk1M+dewBAeXl50pFKbJaUlGDw4MFYvnx54lw8Hsf
y5csxbNgwss+wYcOS2gPAsmXLEu0PPvhgVFVVJbVpbGzEqlWrUo5pAtdi86677sKwYcNwyimn4N1338Xq1avx1ltvYeDAgb7GOQZp0UtLSx03i2AA9QEjD2EhSHDuTb/vYZPCmJprNuZv8pqc9eGuoa4ApcYnD2J8zX52zHmwBCLRhhyLgjM+dw2ZYjZIjwcbca3DLZMnT8ajjz6K+fPnY926dbjsssuwa9cujB8/HgBw0UUXYcqUKYn2v/zlL7F06VL86le/wgcffIDp06fjjTfewKRJkwC0xBhfffXVuP322/Hcc8/hnXfewUUXXYRevXph9OjRRtaGwrUb/YEHHsCzzz6L0047DQBw+OGHY/Xq1bjppptwwgknpAx0NcHkyZMxduxYDBkyBEOHDsXs2bMdi77//vsn4h9++ctf4vjjj8evfvUrjBw5EosWLcIbb7yBRx55BEDyon/ve9/DwQcfjFtvvdX3RRfygGx/0/mBSYue7vp4Wdcc/kzI/dPZnX1+39R9oV6TmoPJeVEihXu/ctp5GUs9lw23PXFNm5qHam1kris1FuccKVQJiyd1q9C3nUX+dzawv/2f2z5uOffcc/HVV19h6tSpqK+vx6BBg7B06dJErsnGjRsRapPUOHz4cCxcuBC33HILbrrpJnzve9/Ds88+i8MPPzzR5vrrr8euXbswceJEbNu2DT/84Q+xdOlSlJWVuZ4fF8t2afffsmULunXrRv7bX//6Vxx//PFGJpaK3/zmN7j33nsTi/7ggw+ipqYGAHDCCSegd+/emDdvXqL9008/jVtuuQWffPIJvve97+Gee+7B6aefnvh327Yxbdo0PPLII4lF/+1vf4vvf//77Dk1NjaioqIC27/5H5SXdzL2XvOaHBYGALIz/2yICg4m5+W32OS+R85Y3Lky2pFi08tcTYpBk2Izpm4knqIdR8BFibE44onaG506R4muGNVXmQd3rtx5NEeSXtoRok2EmCt1bm80bRsQ49t7iXNUu6aY0ib9HAAgHnUKx2jEeS7Wpt2O5ij6LHwF27dvz6h3sfV5//U3f3L9vG9s3IX9up6V8TkHAddiU3AiYjMNIiyDd81CsGKaFJbcdsyxWJZMjsjjXpPbT/eaXsaiBKgq2KixmALRf7GpnPMiLBl9SQHHFHWqKKXacMfizMNuovo5JUc04ozoIwVoc7LY7Pffy7MmNrd880ctsdmt69kFKTalwJsgCIIgCILgG65jNgUhLWLJDNb1gmDF9NLXi7vXT4JqxaTambRiUue8jMVxfXMslqnOqRZEbpwix4pJ9eXGlzKttQ5rIeX6bua61jWtpNxzTaqbXt+KGSPOxWPfnYtnO2bTtmG7/E4rZEeyiE3BGyIsg3U904Ir0y5yv4WlSZc51Y2b6KPO37SINykGOULVi5te1/XNdq0rD3jCVU0KS12B6MGNzhJ1hLBUYyVb+ukJV04/AHRspyIuScHIFJaxNDGbVPRFJslUglC+IGJT4JNLwjIoc80lcSnC0lVfbWFJjW96rrpiUNcCSvXTjamk+uoKS8Ah/rQtlqnaqeJSSehJ1Y9rLVTFJVtYkkk9mjGbmrGXpoQlNR41dibRKWWkU/ooXxCxKbQQFHGmQ1DmLsLSe9+AZosDhLj0O1ucO5bJbHHdpB6u6zgIST2aiTnsvlxhyczw9jOphytc40RFw1jM6Q5vm8AD0IKwrSv8u36Ea51o1/a2o7LVM4ltxzXc6AF5VmUBEZv5Tj7e3OL6Tk2QYypVApIt7rvr22Q2t+743JJDbNd3ete0JwHKsGyyyhBRYxkWmw5x5iFbnGVVbHa2iROWR3IejNJHOmWIEucUyybZppk4RwjLdFbRWDS7+c3iRneHZKMLgiAIgiAIviGWzXwi162YQZm/n/OQBB4z4+dbnGVQEnhMurn9jrPUtZJyrZhE7KXvCTwG4yypBJ54U3I7jnUydbv01k6ulVQnG53690wibnR3iNgMIvl2Qwb
l/RSKO1x3rCC4ww3ONad25fEiLDkuco4rHAjurjwGi6KzXOGA0V15qAQbyh3u/648DDc3U2xyhKSOiGx//O/uxYhko+cUIjYzSVBEl0mC8J5ySURSBDnOkvPefY6pJJvpCsmg1q40mawDmN2Bh9PO5A48gK9i02SyTktfhthkx2emn4dusg7gFHCmknVSjc8VllS7KJHtHmlzronadz2DSDa6O0RsmsSOB0N8mSAo7yMb8wiChTIIItLL+JlO1vFwzUCUE8pHN3dQywkFxM0dhHJCfrq5U7WhRSSRqETcd0mWzVh2n1EtRd1dWjalqLtQUARBSIqIzFzfgLi5je6kw7mm6XI/flo2PZUJYrjDNUUkwIyXNOnmJs7RwoxYC0pIsvbr1rd2OtoxYiWBVAIxnLYdbY1kCsQYQ2xSYzGtkXHF2hiJOt83JTapc9StLpbN3EWy0QVBEARBEATfEMtmPhEEiyWQ+XkUSoY3RQBiKgNRp5I6F5QMb90dcrwk8DAysEmXnm69SZN1KkG4tRkWy5Z+jELmuhZLIMUWjcnj05ZBKqaSFxvJKZTOrVPpGIvoZxN7jpOub4bVMk4atwkrJmGljKSxdkap+z6DSIKQO0Rs5gpBEJKFmogjru/2CUL8JNXObze6Sde3rohM0dchJE3ukEO10xWRgFHXN6vguYcEHo6Q1BWRgL7rWzfDW1dEtozvOOUQkjoi8rtz7V8zyyGbUvrIJSI2s01Qbr5cqi2pEuRsbhXZNaf9NiatkV7GyrQ1UldEphrf711zmtJnYAfVGqkrIlvapReSuiIyZTtGzGa6MkGtqELSS/wkR0jqiMj22rUVmJGsWzbdWyoL164pYtMsQcxGzyVrZDZqS+r2DUI2N9GuYLO5qXa6IpLbLqhFy4nxdUUkwBOSuiIS4AlJrkubk4jDScIB9K2Rui7tVNd0ik19a6TDzU18tKTlkbgmR0jqiEjuNaPZzuzWsGwGTh9kEBGbuUqhikgv4wWhLJCHOo8Zz+amzgU1mxtgWh5Nik0P2dy6wlVz9xtWSSBAuyyQ0WxuwCEkg5rNDTj3+jZdWzKm3Bumsrlb5mXGpZ0Yn4jHdPbTt5wGSWxKzKY7RGwGkUIVkrksIlO1Y/QtWGtkPtaWzMctFBlC0kuCTaatkboikntNXRGZqp2f1khTLu1U1+SKSIp0ojeedbEppY/cIKWPBEEQBEEQBN8Qy2YmEYulP31NJ91w+mXD9W3SihkE17euxZLb13Qhcz9d37oWS4BltdS1WKbsm8OFzHUtlqnaqZZHXYsl4L/r2zlXZz8KXdc396uZmitF2z/LLNd0lx2EXCJi0yS27Z+gzDchGWTXN6NNILK5qXNSW7LdfkGtLalbEqilLyMRh1MSCNAuCxTk2pJ+ZnNT7fysLdny2pzr20v8JAVnrhTUb7ugI250d4jYDCJBEJZSW7J9ghA/SbXLRm1J7XI/HmpLEvNwWCMDUluSFHAG99OmhCRHzNKWTef4tBBLFo20xdIpLEkLImN/bqkt2X4bnWzuVG0oTFooucIy3VjZthKKZdMdIjYzSRBEpJfxgiAiue28uLl1d+Ux6fr2280dp4QeQ0hmoXSQdiKOz2LTixtaFaBssckWpend6PEm57RihDVSdycarhUzRt2KaoKQQTc3JfxolzavHcdaSFs2HaeMWiMpMunm9joW2bfN+FkusymWTZeI2PSLQtlCMd/c3B6umXdubqqdyQzvHHJzA4R1MOfd3PoC0U83NwDEIhyxmf9ubmoeueTmNiUsg0jcdi94sy2Qs4lkowuCIAiCIAi+IZZNk8Tj7i19QUjW4fYNSEylo4muKzzV+BzXNHcsv3fNUd3h3JjKLFg2He5w04XMNS2b2u5wL25upR0385y2klLu3mQ7gpcdbDjWTo4rvKWdnjucjLPUdIdzXOEt89Jzh3uJqaTgxFkGMaYyZb+AWyy5xGwLMdt576brU6jkjNjcunUrrrzySjz//PM
IhUI466yz8MADD6Bz584p20+bNg0vvfQSNm7ciO7du2P06NGYOXMmKioqEu0sy/nh/+EPf8CYMWPMv4lsCMsgxFkanqtDXJoUltQ5LyEFnLGoJ7WuAPWUza25hzcnWYca34tLnrVrjl6yDsAUjcxC5hyXvMlkHYDIwGb2o9px3Oi6yTot80jvDidd35TYZLjDKeHHjdnkuMN1k3VS4WecZa4JyyAVdQ+iGz3IOilnxOYFF1yAzZs3Y9myZYhEIhg/fjwmTpyIhQsXku2/+OILfPHFF7jvvvswYMAAfPrpp7j00kvxxRdf4I9//GNS2yeeeAIjRoxIvO7SpYv3CWcjm9vLNXWTYnTbBTmBx884Sy9jcayWJoUlwErqYQlLaiyTVkwYtjxyrIpe4iwZCUKcLG2AlwlOtqHqTWalnBBj1xyDcZZetmPMdAIPt28QE3i84mUemSBuW4i7tFS6be+WIOsky86BXPx169ZhwIABeP311zFkyBAAwNKlS3H66afjs88+Q69evVjjPP300/jZz36GXbt2oaioRWdbloXFixdj9OjR2vNrbGxERUUFtjcsQnl5R3edRVh+1ywIwpK6ZlASeDii0aSwJOahLSypc4YTeLQzvDVLDJlM4DEpLKl2JoUl4BSXJoVly/hKmwJN4KHbiLAEgD12FFfgb9i+fTvKy8vNTSYNrc/7Nzc+gM7lHVz13dm4B0cf+Etf5hx0nZQTls2VK1eiS5cuiQUEgNraWoRCIaxatQpnnnkma5zWD7h1AVu54oorcPHFF6NPnz649NJLMX78eNJs3EpTUxOamr6rG9LY2OjyHWmQZ+5wo3GWJssQmUa7NJGmaGSPlYXalZz4T5O1K7k78FDjM2pXsl3yxDln7UdzwpIa36SwpPpyhSUp6jRrV5LCj2G1NCksU12TQ6G6w4NuseTixbKpaobS0lKUlpZ6mk/QdJJKTojN+vp69OjRI+lcUVER9t13X9TX17PG2LJlC2bOnImJEycmnb/ttttw0kknoWPHjnjppZdw+eWXY+fOnbjqqqtSjjVr1izMmDHD/RsRBEEQBKGgqa6uTno9bdo0TJ8+3dOYQdNJKlkVmzfeeCPuvvvudtusW7fO83UaGxsxcuRIDBgwwPGB3nrrrYn/Puqoo7Br1y7ce++97S7ilClTMHny5KTx1ZuHxG/LWiFYMalzpkMKTO7Kw2nDzRbntPOw3zireDp7roxEH24/4pq0BVEdn5grd49wRjtuP9Z+3cyddcjbjmGNZN/CTAuo0zVNjeU8x0/E0U26yXy2OG9ezL5ixcwZvCQIbdq0KcmN3p5VM1d1kkpWxea1116LcePGtdumT58+qKqqwpdffpl0PhqNYuvWraiqqmq3/44dOzBixAjss88+WLx4MYqLi9ttX1NTg5kzZ6KpqSnlDcAyeQelDFEWYBdUd3Q0+J6ysSsPK0HIg1jjnPMkZjM8PrcMEeX6ptSBQwwS43OFK6cdMQeqDFE8nl7A2VQb5g45rCIFTJc5u/5/RBWDTOFHnnOObzLpxiRe4jGdY2U/9tLPbPF8x0vpo/LycnbMZq7qJJWsis3u3buje/fuadsNGzYM27Ztw5o1azB48GAAwCuvvIJ4PI6ampqU/RobG1FXV4fS0lI899xzKCsrS3uttWvXomvXrp7jJ9j4be3MQqKPAy9PBb8tm7pkw9pJ9lVjNnnCj8wL9Nka6djukXpIM8+RAlFtR41FxGJyhCsAx3uiRSoz5pH1e4YSg3qJPuRYxIOSnj9lCWTEQXKN7AyxZjKDPCixmEJuE//2cNvHLfmik3IiZvPQQw/FiBEjMGHCBMydOxeRSASTJk3CmDFjEhlWn3/+OU4++WQ8+eSTGDp0KBobG3Hqqadi9+7d+P3vf4/GxsZEUG737t0RDofx/PPPo6GhAT/4wQ9QVlaGZcuW4c4778R1112nN1E77p/wyfQ3mJcMcoPja+N3uSL2PBgJPGQ/D9ZO7bEYAte
olZQQkdT4pDpIPw/SOsk9R4pedfy0U/i2XXpRRwlSk0UWPBVBYN1iPCsmhUkXtt/kshUTEBe5SeLQSBCCf6WPgq6TckJsAsCCBQswadIknHzyyYlipQ8++GDi3yORCNavX4/du3cDAN58802sWrUKANCvX7+ksTZs2IDevXujuLgYc+bMwTXXXAPbttGvXz/cf//9mDBhgj9vIpesmF4wuV96NqyWHHStmBS6VkyyDdOKycWkG12Fa8Xkrg/HskkpG6odQ5SSIpJ4+HB+93CtmNT4FOp4XqyY5PiqwZsbR2hQrGWjDiaFWDILkyAWdQ+yTsqJOptBJ1Fns35h+3U2pTZm++Prik2TMZXUOc42kSnbMSybXnbS4cRBchJ/Uo3P2fVHszg7vdsOFQfJLH3E2QKSW3uziZpb8rl4k6MJ6dImt4V0lA7S382H2oBK3WLSS9F16lZR61lyE390t4rkbBOZ6ppSdL1N3zyybGa7zuZfP/4NOu/jss7mjj04vs+kjM85COSMZVPIPoFxmfu5ww81nsn5e3Gjk+00v/B1XeteMsh1XaOUtZNjAeVaSbmue8d4PNc3BcdC6cVNbxJu8o+zDW/8oCT/OOYQAJe5F/JlD/KgInuju0PEZhDJxjdtpt3VQXWPm0Y3ptLkWEF4cgMOUceOzyTHSt+O637nimDebxyeuzrd2G4wWVBBFy/xmUL7ZNuCGJQ5BA1bw41eyH5kEZu5SlDiM4XvyKUvZN1EHy8JQrroxmzqtskA2dillgO3XmYQ8LteJn1NY0MJOU4Q90YPMiI2BW+Y/PbNJWEchGQgL/MIKOxkIJ/hzIObrKMtLHP8wWRU5AVU8BYKoRCV+BaMv9VsEbPd/3YNyG/drCBi0y9ySTgJ7ZNDn2VO5ft5+eZlWTFz6HPzOe7SNH6Lv0IXMkLwEcumO0RsCoJQ2HBjO3NMEAqC4B9BLH0UZJz1NgRBEARBEATBEGLZFAShsCHi0SiskFpaSSydglCoSMymO0Rs+oVFGI1zKPZPaAP1WWrtcus/lkWU3snCPAJBmPjcqD3PA4BDyAKBFrOhcPJr03lqakJKULevFAoXidl0h4hNITgUgkCnrGjUW+S2y2XCxHuMZKBvhiF/qzASbCyLEli587AKEZ8RVSBeEHIRidl0h4jNfMdvARcixs+zcjwk5PvOQn0WdR7ctafELPWeVItbiBif0w8ARy1bxLzY38+quKHEJ1Ok0vNInokXMUgKUB/7eUG1YgKZ/xOn5kCvdfrPRKykggniGjsIiWVTyD1Mi0h1PL8tivloxdS1RnIFu67ADcoPAmV92MKSK4yVxeaObxEC1CYFKDU5tY3zCrRlU12L9GOnvmb6c5QuJn8jEOc4FaS4dRipyIZcqqGZjT8lR0gB0zxG/kjIobUOOmLZdIdkowuCIAiCIAi+IZbNIBIUS5SCZTl/Kts2J/gsG658L+MxxuLOn7TAqeMTa8i1YqoWJdI0xRyfcn2rbnPdfnBaECnrIW36IsZnuL5twmJJWjspyyljHtRScGvqqy548uMm5kUttU2GLJhz14Upq6UavUG4ptm3MPE5BaGoOzkvhguebeUlLI/Un6oQTCRByB0iNjNJPrqOKXTjCCk4/sVsrCE1L9Kzy3FzG0waIp7wFrH+7Ee5Oh5TWNIiW1kLKn6SKJxuUVnZlOtb7UsVYeeKWWKtVaFqUYKaXApK4ConiNuEdMkT41MCOqT0pf5EyFuYKfxUFy11C5CRIMykIeftoxefSbfjjcUOfzb4dce7nr44d1QR8CBuC30LSyl95A4Rm7kC5xuNK2Y57ZhjGbV2UugKSZPWTlI36SbieIjPJMdXvr24VlLKBEcqBrWNZmIR1Y6MbySEGSUQGQKXLCdEzp8rZtXxnXOgLZTp4zjpJBxK8DrbUW9TtTyScZ0G4zgp4UHfwrw4TkdST0BEjFg720c3vjQXaYnZdGvZ9GkyOYCITZNYoe++1bkiyW9rp64A1RW
pFF7CAnQzJ7KR4EThcMkzxSCljDnWTrJAOdPaSc1D/YaIUsMzhWtR8pOUtC9R38aU6CpOb6Ekt5csYVo2i6m6l0o2OjEHal3pTOpkqLmGi5xzIP8sKd3N+Gb3Yu0sVl5HCBFJ3QLUp06vT/J4qnhONRbP2qlrEeV9bfHz/cxZBnXHkiQifSRByB0iNv3Ci4jk9A1oXCfgtHayLJ2AWeHNdnPrDa89FvkE1rRistswLaecgEMvbnRGGSVSRBLYlDVJEYikzYHsx1x/pS8pWShrFdHSViwiXGEZZn5jq+NTcZ3csWiU+VO/jZiWTY4opdv460Y3WVrJy9e1au3kWjqDIEDz2dUubnR3SDa6IAiCIAiC4Bti2cwkJq2dVD/uz2eDMZucdmRcJ2kZZI7PwUu2uJ+xnVyXOTUYxypK/UVTrm+iHcuOQ7rH/f0aYVsoFcgWMee9yI66inEqC1A+YadZKMyxMHlwVDjDEbhWOl14qxghqggUqT55ANFI8tw4cZ0t+FvUnQop4FhOudZOsq/yWWYjrlOsnU5i0LBs+jKT3EDEpkkyGbPJ7acrQL3EbDLGyooA5cIpfcRF9VWSvjNKDTJFnToeU1jqClDyuzUL4RtWGXGSFAIMyH6MBaJ8xxRU8owyvkXu2e68zynXdLTZea6oJHm8WJQXyxtkAapCiXF6x9Lk91RUTIkdYnxCGdCiSC8mVFeAekksooSeY15ZKBqfiwLU1ojZ5JZGy0dEbPqF3zGbXvqxaj+mb2Ia0mbA0ZXcuWonOHm4JmcsUuVRnxsVGMd4oFNtqHsgSggsNcM7yutHJhtxaoJy+qVoZ6nqgCpfFHH2symhR11TaWdFnE9WqmanTbVz1Bx1zqF4r3Nd48zaoaq4pCxy7NJH1O2pCNwYceuQ4zOz1tVMfLWUE+C0fqYaK6ooUKqQBCWSoj7HhNICi+jKGp/TzwlXkEppJScSs+kOEZuZxEu5H916k7rX5Jbe8TvpmyMQiZ/1ZFKSbkgBNwWYI1xJEwpzLOqc+pTnWrKZtSUd7YoIc0aUcE1T11Q+J7uI+RlFKSVAqBulncXsR7WziaQhVRBSItIqI7L8CdFoK+3INqXEujY5rxmixKwy1yghsouihIBrpixwRKa8Yh2MEeNTAjROWFjVsahrFhGfByU2I8S5IiVxjOpHlZkKUe0ItaD2pQQRtYZUO1IIM6re0VbS9GKW/rqgLKfEDxqWkNRXV3Thju/mEbKtrBhFWhGx6Q4RmyYJhdoXaX6X+8kGrKkajCWl2pFueqqwIG/9HUKV+7lxBCJzDuRYpFBV/oS5IpUaixKSajuyNBFTICrtSDHIHMumnnSOkAI9kQoAVhlxTu1LzYuwUNp7qXbKObIfIUCpdoy+IapNk3P8EkJgUUJVFY3RZuf4lMCi3PmUAI1GrLRtqLFKyWsmz58Sm1RheaodLVSV14SioEQwJXA5QpUWqTw3PeerjBJ53FLAzvhSXZEKpBOqIRsiNnMIEZsmsaz2Yza51kJdvPzhcWI2Of24famIf7Ifx+XvQcARWKrvngxty4Jw5VpAOW2oeXCEKnesOGXWMidcScup2s6kcKX6Mq2kLAsrNS9SpFJik9GXYf1s6ed8T2FSqCaPR80h3uQ4hRhl4aPEpmJh5QpXjoWVK1zjREIY1U4VoLSw1BOu1HhcEUkL1/RtvAhX1oZpXKdZmq+2iG2lCtLNCFJn0x05ZEITBEEQBEEQcg2xbPqF365w3y2nPu/6YzK+lK6LwrsmBw/WSEv9E+PWsyHGYhXH51osqQo9Ji2bmR6L6st001OF2El3u9qO6aa3qRRUhmWTtpLywgDUc6T1k4j/RDPTAqq44MkkK8q1TrjuOSEE1Pjc+FLVGklbP/Xc+1Q7rpWUHp9wwUdVy6ajSQoraXprKp0sRfQj4ntpd7ve+BRkmEGbvkUut4o0jbjR3ZEzYnPr1q248sor8fzzzyMUCuGss87CAw88gM6dO6fsc8IJJ+C
vf/1r0rlLLrkEc+fOTbzeuHEjLrvsMrz66qvo3Lkzxo4di1mzZqFIp3Zg29JHpuHUwfQbo5cUMZuA+b4tVua82ZACX+NXvcSXcq7JceV7mQdTpLLELNnGnJjVdu+nnEd6MWsyDpWKLy3SFLPcOXDFpkkxS4llVVR7EbOO32PMJCuumFXHJ0MKDIUBWHEbaHT2yxRxDbHptxs9yDopZ8TmBRdcgM2bN2PZsmWIRCIYP348Jk6ciIULF7bbb8KECbjtttsSrzt27Jj471gshpEjR6KqqgqvvfYaNm/ejIsuugjFxcW48847fXsvWmTDUprpObDjODl4CAzK9H7s3KQhk9fUFbhciyiF0o69jSlXIHLaeBG4JsWyrpXX5/hVroBW41xZc4AHgUsmY+lZazmW2lTtSjgCVzOJC+AJXG/CNZS2DbeKgJp4RY1nE9ZJjvW2ZSznPNr2LYnZwGZnm0wRxJjNIOuknBCb69atw9KlS/H6669jyJAhAICHHnoIp59+Ou677z706tUrZd+OHTuiqqqK/LeXXnoJ77//Pl5++WVUVlZi0KBBmDlzJm644QZMnz4dJSUl7ibqp2WTg7YwQ25lwJuEY7Vkr6vmGposis7eOclDX0P9SIHiZXxdAUo1MxmyYFIYByRkwfHZcSsLGAxZoO4fMtmLEbLAL7mlmezFFKCcMAauWKbCGFSBzp4XmQDmbKcKYS8VCdL13RGJAe85p5opguZGD7pOygmxuXLlSnTp0iWxgABQW1uLUCiEVatW4cwzz0zZd8GCBfj973+PqqoqjBo1CrfeemtCta9cuRJHHHEEKisrE+3r6upw2WWX4b333sNRRx1FjtnU1ISmpu9SLRsbs2jLb0s+CkaT70nb8pjhHYvc4OUHhimyMQefBbp2NFiGBTsA3lp4Gd9vEW9w/mSoCUd4UxgU9uQPK5PhG7Qfmjhnp2/j5YeD0o5d6UFjHp12NQPPvuxskyG8iE1VM5SWlqK0tNTTfIKmk1RyQmzW19ejR48eSeeKioqw7777or6+PmW/888/HwcddBB69eqFt99+GzfccAPWr1+PZ555JjFu2wUEkHjd3rizZs3CjBkzdN+OIAiCIAgFSnV1ddLradOmYfr06Z7GDJpOUsmq2Lzxxhtx9913t9tm3bp12uNPnDgx8d9HHHEEevbsiZNPPhkfffQR+vbtqz3ulClTMHny5MTrxsbGlpsn2250oX38rnNaoJAJKgHA9rnis21wH3HbsHVYd27ceXDG564/55rc90ONpTtX7n2t9qWux52X7lpw5++YK6MNQBsxY0Q2uMNwSrShrIFUu3TX3Nm4x9kgg8Rsi1yDdH0AYNOmTSgvL0+cb8+qmas6SSWrYvPaa6/FuHHj2m3Tp08fVFVV4csvv0w6H41GsXXr1pRxBhQ1NTUAgA8//BB9+/ZFVVUVVq9endSmoaEBANodN6XJ24IH/5t3gvrQ94JRwcBcHpMiwjF2jokK3etxPjeTwoY7HvdBTY+f3Jd8wGsKDbYY8VlURKhyM8Tyq+0osUCO5RwKUaKd7vhUuKE6Pt2PSGQhzhFhkI5rUm2I+vzkNZuJjejV8anEf/411bGoBB5ia1AiGYhuZ6VtE6F2pKLWP801o3t2O/49k3hJECovL08Sm+2RqzpJJatis3v37ujevXvadsOGDcO2bduwZs0aDB48GADwyiuvIB6PJxaGw9q1awEAPXv2TIx7xx134Msvv0yYn5ctW4by8nIMGDDA5btp+UI3Ifj8tsjw5uCvcDUpurzMNd9El5exeHNlCiDNuXqyCinziBPxgV4sORwBRFltSDFl64kpaizKuhKNJ5cN8CKwqHaODYqoXBeGMKPGaumrCkSiH69kJ/Yw5kqJouZmZ+kFkwKLP376DHKuGAxFk9+8RdywRVHnAlHlu4qIhQwp7ah+ahsACBHzCBOly4ra9I00E1tUZZBMJQjli07KiZjNQw89FCNGjMCECRMwd+5cRCI
RTJo0CWPGjElkWH3++ec4+eST8eSTT2Lo0KH46KOPsHDhQpx++unYb7/98Pbbb+Oaa67Bj370IwwcOBAAcOqpp2LAgAG48MILcc8996C+vh633HILrrjiCq1gXRvxdh/YIuD8uaZJdx13HkFw11HX9FvAscUmox3Xisa2mil9VUEHANE4UQPRdn4NckQXx4pG9aPOca1oHGsV1ZdvWSOuyRB15PhMq1lzEyG6HPuZO8dqYvSjrsm30vHaqTcetb2nFwFXpIguql8HpoAzKQbpvum/Y7njU7S9ZiSyl9XHL4JWZzPoOiknxCbQki01adIknHzyyYlipQ8++GDi3yORCNavX4/du1tM6yUlJXj55Zcxe/Zs7Nq1C9XV1TjrrLNwyy23JPqEw2H8+c9/xmWXXYZhw4ahU6dOGDt2bFK9KTfY3/4PKBx3Kd03+9Y8L5ZHjvjTdaGaFH4tfRmuXYMuVL47k3CVKWn3XGGm60LVFX6p+qqCimu544hG0nrIq2aj7S4lRR5pgUsvuijLHVescSx8pJuVYbkDnOKPEjthwjddRljWwsT46ngckQfoCz1dkUeNryPy3IzPRaefTXwWmSQed18Uw2QRDYog6yTLzsdAvwzT2NiIiooKbN22GOXlnQCI2EzbRsRmyjYp2+WU2Ewfrydis/05iNj8Dv/FZnoRCYjYdDs+F51+kchu/HHxBGzfvp0d/2iC1uf9f7zyKEo7d0zfoQ1NO3fj3pMyP+cgkDOWzVwgbsfIODG3SBxh23P+xhFqu8NzPo7Q+YCn3M7qeF7i/Bxubg/Cr4kST6rA0hR+qfruVa5Jik1iLDomMfl1E7lzjJ7LuaVv8ni6LmduO92YQSCV0Es+V0K0KctxMcgVWBxBaFIM6gpGgC9edWn7PmPETkSZJGhF3YOOiE2DtCQIuROKIhDb7xeERBO+GCR22VC60skclEAkHt6M5JCoZnxjqnYmM3k5MY8cYcbtqyaZpOxnMNGEKwbVdlRMItWPFI0MgeglJpESiGpiSRFRgLskEnGOFQCB6CX+kCXWmOOTfTNoGQTMi0Pu+8wHYtAQm77MJDcQsWmQtjGb9L+bS2QxWQ6GGk+3HAzVV7eeHLevlxIxpKXRYY10zosWjVTyidqGmU2smZ3MzobWFI2czOSWds5zvKQVqh/zmrH0bSjhqmtV5FgUAZ5V0Vu5mfTtuBbFMmLrSEqQqAkvXCsglQBjUiBSqKLRVIJK6uv5KyzJsQyKxqAIxnTrY3L9dAji3uhBRsSmQdJmoxsUWNx+dDu92D/uNYMQp8ixMgI80cgVcBzRyLUo6lsL9d3c5JbIyjnVupeqX6bL2VDtdK2MAM/SqGtlpNqZdkOXxmNp2+haGQGnuNHtB3AFnP74OtdLNb6f/YD8E43ZFoR+Im50d4jYNIhtxxMiKhuWx6CU0FGJE84D/0vo6Cea8ErocOsdctzQjlM5X0KHk9zCnUMTlaDCcU0zLZa+l9Chilgr60MKxqjTDa1b75AUgznumqYIguUxH13T+SwadRGx6Q7Zv08QBEEQBEHwDbFsGiRdnc2gluOh2mWjHE+MuCa9/256a6HJ7fO8uLk5MZtNMWLHEU03txcr5h4iw9tRvNugm9vLri3NTemtneSuLQazsvlubudiFDMKgZt0c7PH0rQgms62NmmhNNUPEDe38B1BK+oedERsGiRd6SMvNRCdbTzEQXLEpoe5qmtgMsGmZTxGnCIpUp1j6bqmTSbYmMzK5sRKpmpHle3RjdnkiDpurGRQdopRhWRxhFn0O88SbLjjUwShRE8my/NkChGNmUcShNwhYtMgbROEdEVkS7v8SrrhWPxSt0svJE2W9qHaZWMPad1Mba6I1M3UDnLtR2c5ISJWkviMKGskR0hyRCTAE4O6IjLVPDiWR4qgJN0EQUgGwfIIiJAMKhKz6Q4RmwZJb9k0Z430Uu7HrDVSz3XspdwPpzh4UMv96GZup7xmAZT74e50owpJWkQ
6FygI1shc2Y/azfhe2jn6iYgUAkY8bpE/XtP1KVREbBpEJ2YzCIXGAaf441oGObGRJl3agL+FxlvOtf8a0C8BlEuFxgGnkAxyoXFVSHK3JdS1RlICSDem0qRLO9XcVILg0gbyT0iKiCwM4jELcSLOPV2fQkXEpkHidjwh7ky6tKl23hJsHKdY2x76nWBDXzP9+CYTbAD9Ej2cBBsvLm2qnSrETCbYAM4km6wk2MScb5wSiKw9sD3Ug1TbccsQUeRygo2XvvkmLAERl4WKWDbdIaWPBEEQBEEQBN8Qy6ZB7DYxm17iJ6WQebq+Usg8MX5OFzLnWSylkHn745N9M1wWSAqZC4WGWDbdIWLTIOmy0Tn7cLe0c46dyxnepmMqcznDWzcxB0glNtPXluQKRI4LniMiAfpBXawstm5iDsAToLoiMtU8cjnDW2Iqv0NEpGACEZvuELFpkLbZ6LkUU0n1NRlTSY1Fx1k6xzIZU0kLUOI9aWZ451tMJdWOjqnkCcRMx1SaLFreMn76+RdqhjdFEIQlIOJS8Ac77j5ByBaxKZigrdj0ks3NsVB6KR3EEZtessWblD9Arptb10JJZ4EzXeuMrG/askmMz8jK5lgngRS733Asjwbd3IDT1c0ViBw3t657PFXfILi5KbjWVA5BEZZBEJIiIoVsIpZNd4jYNEgkbiXEl8nSQYDTQsl1OXPHd2wlyBaW6S2U5t3car/01slU43PmYXJbRY51MlU7XTHLz/pOb6E0ua2iyfhJajwvwrJQ3dwUIiwFwYmITXdINrogCIIgCILgG2LZNEg0biUsgiaTdQB9N7fJDG/dnXS4LnPdnXRMbscIEHGWmtsxAs6sb5PbMVJ9udsxUhZKToyjbgF0wOkiN5ms09JXdtJxQxAsloBYLYXcJB5vOdz2KVREbBokms6N7qE0US6XE+L304uzNF1OyJnUwxOIJrdoLIRyQqaFpZQTSo0IS0Ewi+wg5A4RmwZJitkk/t1knKUXYZnpOEvdLPBU43Oy0YMaZ2l6i0Y1zpKTBQ74H2epmwkucZbuCYqQVBFhKeQzErPpDhGbBonCSrjATe79TbUzufc3dU5372/A6fo2ufc31c7k3t9UX929v6l2noQlw/XtpU5lpl3fuVSn0k07R788FJYiJIVCR8SmO0RsGkQnGz0IMZWAl9qVuRNTyS0xxHF9c62kjphKsgB65mMqdUsMSUxlmn4iLAWhIIjFLIRcusVjBexGl2x0QRAEQRAEwTfEsmmQWJts9KAk8PBd35x+egk83DnoFkrnJvAEoVC6l/hJ2Q88TV9J4PGMWDEFgUfc1nCjE/kYhYKITYO0daObFJZUX5NF0am+pEuecAHr7vCjmxkOOMWlbmZ4qr6sDHL2DjzJb960m5sTsxmUQunpxk41Ptk3w8ISyD9xKcJSEPSxNWI2C3m7ypxxo2/duhUXXHABysvL0aVLF/ziF7/Azp07U7b/5JNPYFkWeTz99NOJdtS/L1q0SGuOkfh3R2vNzbZHhHlQfffGkPZoe/3Wg9OvpW/yHPZG4Tia486DM4+m5pDjiEadR3NT2Hk0h5yH0qaJOKjxoxHiYMyD6hePwnGEI3HHURRNPsg2xBGO8o5QzE4+4s7DIg6qnWOsGNUvThzOsSjSjZ1KgHHHN9WPmqsXocl9n37iZS0EQXDSmiDk9vCTIOuknLFsXnDBBdi8eTOWLVuGSCSC8ePHY+LEiVi4cCHZvrq6Gps3b04698gjj+Dee+/FaaedlnT+iSeewIgRIxKvu3TpojXHtkXddfcRb2nnHJvn5k7fr6Vd+rJDuls7ApSbm5vNzXOjq+5w3QLoqfqq7XS3dgSclk3dkkMAz2rpd6F0ttDTtFr6ncBDjpVnFktArJaC4DdBrLMZZJ2UE2Jz3bp1WLp0KV5//XUMGTIEAPDQQw/h9NNPx3333YdevXo5+oTDYVRVVSWdW7x4Mc455xx07tw56XyXLl0cbXVoW9SdU8s
y1Tld17TJepa6whJwiktOLUuAX8/SUbvSw648nLJDlqawBJgxlZrCkhrPm2tar84mRRCEZSHsGQ6IsBSEbBC00kdB10k5ITZXrlyJLl26JBYQAGpraxEKhbBq1SqceeaZacdYs2YN1q5dizlz5jj+7YorrsDFF1+MPn364NJLL8X48eNhWalviqamJjQ1NSVeNzY2tpyPWQh/KzJNCkuq7x6mRVS3niU5V804SJPCkuprvFA6xxqpG1PJFJa6otSksEw1N5V8LJROEQRxKcJSEIJBy3aVbsVmy/+3aoZWSktLUVpa6mk+QdNJKjkhNuvr69GjR4+kc0VFRdh3331RX1/PGuOxxx7DoYceiuHDhyedv+2223DSSSehY8eOeOmll3D55Zdj586duOqqq1KONWvWLMyYMcP9GxEEQRAEoaCprq5Oej1t2jRMnz7d05hB00kqWRWbN954I+6+++5226xbt87zdfbs2YOFCxfi1ltvdfxb23NHHXUUdu3ahXvvvbfdRZwyZQomT56ceN3Y2Ijq6upEQgygvxtO6r65uysP1S/Iu/I4XN8+b/foxbXOc1eLFdMtYsUUBKE9vLjRN23ahPLy8sT59qyauaqTVLIqNq+99lqMGzeu3TZ9+vRBVVUVvvzyy6Tz0WgUW7duZcUQ/PGPf8Tu3btx0UUXpW1bU1ODmTNnoqmpKeUNkMrkHbW/E20mhSXVzqSwBHhxkEHY7pFqZ1JYUu28bPeoiqJcEpb8a0rspVdEWApCbuElQai8vDxJbLZHruoklayKze7du6N79+5p2w0bNgzbtm3DmjVrMHjwYADAK6+8gng8jpqamrT9H3vsMZxxxhmsa61duxZdu3bVip/YEwPwrbjT3Uc8ZV9GoXRdYQk4BRxXWHLqWZoUllQ73X3EAd5e4l4yvHVjNnVjL0VYtk8QhCUg4lIQcp1MFXXPF52UEzGbhx56KEaMGIEJEyZg7ty5iEQimDRpEsaMGZPIsPr8889x8skn48knn8TQoUMTfT/88EP87W9/wwsvvOAY9/nnn0dDQwN+8IMfoKysDMuWLcOdd96J6667TmuekThQ9O2znyss6Z106LHbYlJYUu10hWVLX80deDTd4SaFJcDbD5wqum5yBx6KTGeLU+PrtiH7ibAUBCFHCVo2etB1Uk6ITQBYsGABJk2ahJNPPhmhUAhnnXUWHnzwwcS/RyIRrF+/Hrt3707q9/jjj+OAAw7Aqaee6hizuLgYc+bMwTXXXAPbttGvXz/cf//9mDBhgtYcm2KA9a0I9CIsOTUum4ialFwBx7Fa+r4Dj8E4S5PCEnCKA11hSc2DuwMPhcRZtk8QxKUIS0EoDGwNN7rtc53NIOsky7Zt+Xb0SGNjIyoqKnD7Px5BWeeOAERs6sxLxGb7iNhsHxGbglA4NEf2YNELl2D79u3s+EcTtD7vj575PwiXdXLVN7Z3F9689ccZn3MQyBnLZi7QHLcQ+lZkmhSWgFNcmhSWgFNcmnR9U3Og3Am6sZcmhSXgFJcmd+XJJWHppp2jXx4KSxGSgiC0Eo9bsALkRg86IjYN0tQ2QcigsAR4Ao4Wfv4m9XC2hTQpLAGnkDEpLKm+Jrd7pJCknu8IgnUSEGEpCEIa4nbL4bZPgSJi0yCRGBD6VlCaFJYAd4tGf5N6uPuNx5XMeUpwZSNb3O+kHhGW7hBhKQhCrhKK266/O+wC/q4RsWmQvXHATiQIOf9dV1gCVGki3naPXHc4Zz9zjrAEeEXRTcZZUiKS+hIIQpxlLmWLA/knLkVYCoJgAitmw3L5/ei2fT4hYtMgkThgfatBTArLlnbJ50wKS6qdrrAE/E/gUcWlSWFJ9c2lOEsRlt8hwlIQBL+wNCyb8QL+TnIqCkEQBEEQBEEwhFg2DdIUtWB/a9E0acUEnJZMdkwlN6tcTeohrJi6e4R7cXNz+pq0Yrb09S/2Mh+zxYNgxQTEkikIQuYI2e4tm6ECrjQpYtMgzU1hoKh
F8PGzufViL00KS8ApLnWFJaDv5tYVpUEVlqmuyRnfSztHPxGWgiAIRrHituvvw6B8f2YDEZsGiUVDCSHX3MwTeZSwZMVUGhSWgFNc6gpLqq9JYUm18yIsWXGQBSosgWB8OYqwFAQhaIRi7r9vQ0SVmkJBxKZBmptDsItaxJ1JYQkw62BqCkuAV06IKxA5RdezkS3OFnqM8cl+ugJRknoSiLAUBCEX0Cl9VMjfbyI2DRKNhGB9K/hMbvdItTMpLAH9ckKcTPNsZIvrCstUc9Md31Q/QISlIAhCUBCx6Q7JRhcEQRAEQRB8QyybBmlqDiMe/jZByIMVk7VTD6O+JUBbMTnJP7pWTKqv39niQbVieumbb1ZMoLB/1QuCkF9IgpA7RGwaJBYNwfpWPHK3duSKUnUvcWonAl1hCRBxlkwxqBuzSaEbe5lLu/KYTuBRCcqXmQhLQRDyGXGju0PEpkGam0KIh9yVPuIIS4CZLe4lzpIh4Dj9qL58kac3PkUQhCWQnxZKlUL+AhUEoTAJxWyNbPTC/a4UsWmQaCQEtGajGxSWAM9ayN3ukRIHakJQoW736Kado18eCksRkoIgCE6kqLs7RGwaJNIcTlg26cxwZ5GtQiiULjGV7SPCUhAEIcfQiNmkci0KBRGbBolHAetbkWkyWQcwG1MZ1ELp+ej6VhFhKQiCkPtIzKY7pPSRIAiCIAiC4Bti2TRIUVMMRVaLq5xbJkg3zpJyjxdqnGVQrJhBsFoW8i9nQRCETCEJQu4QsWmQcCyeEJkmd+Ch+krtSo1+eSYsARGXgiAI2UDqbLpDxKZBSpriKEKLZVM3C5zb10tMJYWftStzKQscCMYXgohIQRCE4BKKx9nP17Z9ChURmwYJR+MIh1tuJt0C6Kn6qnhJ1hE3d7AQYSkIgpBbSIKQO0RsGqQoGkdRqEXwiZs7TRvZSUcQBEHIUSRm0x2SjS4IgiAIgiD4hlg2DVLUFEOxnTpmk0Jc3+0TFAulilgsBUEQChfL1kgQkh2EBBOEozGEQy1iM9fd3GS/PBSRIhoFQRAEt0jMpjtyxo1+xx13YPjw4ejYsSO6dOnC6mPbNqZOnYqePXuiQ4cOqK2txb///e+kNlu3bsUFF1yA8vJydOnSBb/4xS+wc+dOrTkWReKJI8w9os6j9Sb27YjpHRSt5R/cHibxshaCIAiC4JYgPnOCrJNyRmw2Nzfjpz/9KS677DJ2n3vuuQcPPvgg5s6di1WrVqFTp06oq6vD3r17E20uuOACvPfee1i2bBn+/Oc/429/+xsmTpyoNcfWOpvhaLwlWYhxUDdjW9Ga6uAKRErg6gpEk6LRpHgWBEEQhExi0mhjiiDrJMu2cyuIYN68ebj66quxbdu2dtvZto1evXrh2muvxXXXXQcA2L59OyorKzFv3jyMGTMG69atw4ABA/D6669jyJAhAIClS5fi9NNPx2effYZevXqx5tTY2IiKigqcc8YjKC7uCEDcxIIgCILgF82RPVj0wiXYvn07ysvLM3bdxPP+/z2CkuIOrvo2R/bgv/880fc5B1En5Yxl0y0bNmxAfX09amtrE+cqKipQU1ODlStXAgBWrlyJLl26JBYQAGpraxEKhbBq1aqUYzc1NaGxsTHpEARBEARByBX81EkqeZsgVF9fDwCorKxMOl9ZWZn4t/r6evTo0SPp34uKirDvvvsm2lDMmjULM2bMcJyP7dmNUFQsiYIgCILgJ5HIHgAt1rlsEGvejahLz2Es2jJn1UBVWlqK0tJSY3Pj4qdOUsmq2Lzxxhtx9913t9tm3bp16N+/f4ZmxGPKlCmYPHly4vXnn3+OAQMG4E/Lrs7epARBEAShwNixYwcqKioydr2SkhJUVVXhTy9drdW/c+fOqK6uTjo3bdo0TJ8+nWyfqzpJJati89prr8W4cePabdOnTx+tsauqqgAADQ0N6NmzZ+J8Q0MDBg0alGjz5ZdfJvWLRqPYunVroj+F+iukc+fO2LR
pE2zbxoEHHohNmzZlNIZEaPmlWF1dLWufBWTts4Ose/aQtc8erWu/ceNGWJbFjhk0RVlZGTZs2IDm5mat/rZtw7KspHPtWTVzVSepZFVsdu/eHd27d/dl7IMPPhhVVVVYvnx5YtEaGxuxatWqRKbWsGHDsG3bNqxZswaDBw8GALzyyiuIx+OoqalhXysUCuGAAw5ImMbLy8vlCyhLyNpnD1n77CDrnj1k7bNHRUVF1ta+rKwMZWVlGblW3ugk47P3iY0bN2Lt2rXYuHEjYrEY1q5di7Vr1ybVeurfvz8WL14MALAsC1dffTVuv/12PPfcc3jnnXdw0UUXoVevXhg9ejQA4NBDD8WIESMwYcIErF69Gv/4xz8wadIkjBkzJuO/lgRBEARBEHQJtE6yc4SxY8faABzHq6++mmgDwH7iiScSr+PxuH3rrbfalZWVdmlpqX3yySfb69evTxr366+/ts877zy7c+fOdnl5uT1+/Hh7x44dWnPcvn27DcDevn27Vn9BH1n77CFrnx1k3bOHrH32kLVPTZB1Us7V2QwyTU1NmDVrFqZMmZKVzLJCRtY+e8jaZwdZ9+wha589ZO1zExGbgiAIgiAIgm/kTMymIAiCIAiCkHuI2BQEQRAEQRB8Q8SmIAiCIAiC4BsiNl0yZ84c9O7dG2VlZaipqcHq1avbbf/000+jf//+KCsrwxFHHIEXXnghQzPNP9ys/aOPPorjjjsOXbt2RdeuXVFbW5v2sxJS4/a+b2XRokWwLCtRRkNwh9t137ZtG6644gr07NkTpaWl+P73vy/fOZq4XfvZs2fjkEMOQYcOHVBdXY1rrrkGe/fuzdBs84O//e1vGDVqFHr16gXLsvDss8+m7bNixQocffTRKC0tRb9+/TBv3jzf5ylooJtiX4gsWrTILikpsR9//HH7vffesydMmGB36dLFbmhoINv/4x//sMPhsH3PPffY77//vn3LLbfYxcXF9jvvvJPhmec+btf+/PPPt+fMmWO/9dZb9rp16+xx48bZFRUV9meffZbhmec+bte+lQ0bNtj777+/fdxxx9k//vGPMzPZPMLtujc1NdlDhgyxTz/9dPvvf/+7vWHDBnvFihX22rVrMzzz3Mft2i9YsMAuLS21FyxYYG/YsMF+8cUX7Z49e9rXXHNNhmee27zwwgv2zTffbD/zzDM2AHvx4sXttv/444/tjh072pMnT7bff/99+6GHHrLD4bC9dOnSzExYYCNi0wVDhw61r7jiisTrWCxm9+rVy541axbZ/pxzzrFHjhyZdK6mpsa+5JJLfJ1nPuJ27VWi0ai9zz772PPnz/drinmLztpHo1F7+PDh9u9+9zt77NixIjY1cLvuDz/8sN2nTx+7ubk5U1PMW9yu/RVXXGGfdNJJSecmT55sH3vssb7OM5/hiM3rr7/ePuyww5LOnXvuuXZdXZ2PMxN0EDc6k+bmZqxZswa1tbWJc6FQCLW1tVi5ciXZZ+XKlUntAaCuri5le4FGZ+1Vdu/ejUgkgn333devaeYlumt/2223oUePHvjFL36RiWnmHTrr/txzz2HYsGG44oorUFlZicMPPxx33nknYrFYpqadF+is/fDhw7FmzZqEq/3jjz/GCy+8gNNPPz0jcy5U5BmbO2R1b/RcYsuWLYjFYqisrEw6X1lZiQ8++IDsU19fT7avr6/3bZ75iM7aq9xwww3o1auX44tJaB+dtf/73/+Oxx57DGvXrs3ADPMTnXX/+OOP8corr+CCCy7ACy+8gA8//BCXX345IpEIpk2blolp5wU6a3/++edjy5Yt+OEPfwjbthGNRnHppZfipptuysSUC5ZUz9jGxkbs2bMHHTp0yNLMBBWxbAp5z1133YVFixZh8eLFKCsry/Z08podO3bgwgsvxKOPPopu3bplezoFRTweR48ePfDII49g8ODBOPfcc3HzzTdj7ty52Z5a3rNixQrceeed+O1vf4s333wTzzzzDJYsWYK
ZM2dme2qCEAjEssmkW7duCIfDaGhoSDrf0NCAqqoqsk9VVZWr9gKNztq3ct999+Guu+7Cyy+/jIEDB/o5zbzE7dp/9NFH+OSTTzBq1KjEuXg8DgAoKirC+vXr0bdvX38nnQfo3PM9e/ZEcXExwuFw4tyhhx6K+vp6NDc3o6SkxNc55ws6a3/rrbfiwgsvxMUXXwwAOOKII7Br1y5MnDgRN998M0Ihsev4QapnbHl5uVg1A4b8BTApKSnB4MGDsXz58sS5eDyO5cuXY9iwYWSfYcOGJbUHgGXLlqVsL9DorD0A3HPPPZg5cyaWLl2KIUOGZGKqeYfbte/fvz/eeecdrF27NnGcccYZOPHEE7F27VpUV1dncvo5i849f+yxx+LDDz9MiHsA+Ne//oWePXuK0HSBztrv3r3bIShbRb8tO0L7hjxjc4hsZyjlEosWLbJLS0vtefPm2e+//749ceJEu0uXLnZ9fb1t27Z94YUX2jfeeGOi/T/+8Q+7qKjIvu++++x169bZ06ZNk9JHmrhd+7vuussuKSmx//jHP9qbN29OHDt27MjWW8hZ3K69imSj6+F23Tdu3Gjvs88+9qRJk+z169fbf/7zn+0ePXrYt99+e7beQs7idu2nTZtm77PPPvYf/vAH++OPP7Zfeuklu2/fvvY555yTrbeQk+zYscN+66237LfeessGYN9///32W2+9ZX/66ae2bdv2jTfeaF944YWJ9q2lj/7jP/7DXrdunT1nzhwpfRRQRGy65KGHHrIPPPBAu6SkxB46dKj9f//3f4l/O/744+2xY8cmtf/v//5v+/vf/75dUlJiH3bYYfaSJUsyPOP8wc3aH3TQQTYAxzFt2rTMTzwPcHvft0XEpj5u1/21116za2pq7NLSUrtPnz72HXfcYUej0QzPOj9ws/aRSMSePn263bdvX7usrMyurq62L7/8cvubb77J/MRzmFdffZX83m5d67Fjx9rHH3+8o8+gQYPskpISu0+fPvYTTzyR8XkL6bFsW2z8giAIgiAIgj9IzKYgCIIgCILgGyI2BUEQBEEQBN8QsSkIgiAIgiD4hohNQRAEQRAEwTdEbAqCIAiCIAi+IWJTEARBEARB8A0Rm4IgCIIgCIJviNgUBEEQBEEQfEPEpiAIgiAIguAbIjYFQShYTjjhBFx99dXZnoYgCEJeI2JTEARBEARB8A3ZG10QhIJk3LhxmD9/ftK5DRs2oHfv3tmZkCAIQp4iYlMQhIJk+/btOO2003D44YfjtttuAwB0794d4XA4yzMTBEHIL4qyPQFBEIRsUFFRgZKSEnTs2BFVVVXZno4gCELeIjGbgiAIgiAIgm+I2BQEQRAEQRB8Q8SmIAgFS0lJCWKxWLanIQiCkNeI2BQEoWDp3bs3Vq1ahU8++QRbtmxBPB7P9pQEQRDyDhGbgiAULNdddx3C4TAGDBiA7t27Y+PGjdmekiAIQt4hpY8EQRAEQRAE3xDLpiAIgiAIguAbIjYFQRAEQRAE3xCxKQiCIAiCIPiGiE1BEARBEATBN0RsCoIgCIIgCL4hYlMQBEEQBEHwDRGbgiAIgiAIgm+I2BQEQRAEQRB8Q8SmIAiCIAiC4BsiNgVBEARBEATfELEpCIIgCIIg+IaITUEQBEEQBME3/j/Lr0tLm4g9pAAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Uniform evaluation grid over the full domain [0, 1] x [-1, 1]\n", "N_t, N_x = 100, 256\n", "\n", "t = np.linspace(0.0, 1.0, N_t)\n", "x = np.linspace(-1.0, 1.0, N_x)\n", "T, X = np.meshgrid(t, x, indexing='ij')\n", "coords = np.stack([T.flatten(), X.flatten()], axis=1)\n", "\n", "# Evaluate the model at every (t, x) grid point\n", "output = model(jnp.array(coords))\n", "resplot = np.array(output).reshape(N_t, N_x)\n", "\n", "# Plot the predicted solution u(t, x)\n", "plt.figure(figsize=(7, 4))\n", "plt.pcolormesh(T, X, resplot, shading='auto', cmap='Spectral_r')\n", "plt.colorbar()\n", "\n", "plt.title('Solution of Allen-Cahn Equation')\n", "plt.xlabel('t')\n", "plt.ylabel('x')\n", "\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "88cc341a-8869-4dd9-be77-ef9bf2d8b5c1", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.12" } }, "nbformat": 4, "nbformat_minor": 5 }